#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2013-2014 Étienne Loks

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# See the file COPYING for details.

"""
# Usage example (extracted from the simulabio application)

class HarvestPlotImporter(Importer):
    LINE_FORMAT = [
        ImportFormater('name', Importer.get_unicode_formater(100)),
        ImportFormater('plot_group_number', Importer.get_unicode_formater(3),
                       required=False),
        ImportFormater('geographical_area', unicode, required=False),
        ImportFormater('soil_type', Importer.choices_check(SOIL_TYPE),
                       required=False),
        ImportFormater('cow_access', Importer.boolean_formater),
        ImportFormater('area', Importer.float_formater),
        ImportFormater('remark', unicode, required=False),
        ImportFormater('diagnostic', Importer.boolean_formater),
        ImportFormater('project', Importer.boolean_formater),
        ImportFormater('harvest_n2', 'harvest_formater', required=False),
        ImportFormater('harvest_n1', 'harvest_formater', required=False),
        ImportFormater('harvest', 'harvest_formater'),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 1},
                       through_unicity_keys=['plot', 'year'], required=False),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 2},
                       through_unicity_keys=['plot', 'year'], required=False),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 3},
                       through_unicity_keys=['plot', 'year'], required=False),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 4},
                       through_unicity_keys=['plot', 'year'], required=False),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 5},
                       through_unicity_keys=['plot', 'year'], required=False),
        ImportFormater('harvest_setting', 'harvest_formater',
                       through=HarvestTransition, through_key='plot',
                       through_dict={'year': 6},
                       through_unicity_keys=['plot', 'year'], required=False),
    ]
    OBJECT_CLS = HarvestPlots
    UNICITY_KEYS = []

    def __init__(self, study, skip_first_line=None):
        # get the reference header
        dct = {'separator': settings.CSV_DELIMITER}
        dct['data'] = Harvest.objects.filter(available=True).all()
        reference_file = render_to_string(
            'simulabio/files/parcelles_ref.csv', dct)
        reference_header = unicode_csv_reader(
            [reference_file.split('\n')[0]]).next()
        super(HarvestPlotImporter, self).__init__(
            skip_first_line=skip_first_line,
            reference_header=reference_header)
        self.study = study
        self.default_vals = {'study': self.study}

    def harvest_formater(self, value):
        value = value.strip()
        if not value:
            return
        try:
            harvest = Harvest.objects.get(name__iexact=value)
        except ObjectDoesNotExist:
            raise ValueError(_(u"\"%(value)s\" not in %(values)s") % {
                'value': value,
                'values': u", ".join(
                    [val.name
                     for val in Harvest.objects.filter(available=True)])
            })
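        # fetch or create the HarvestSettings entry linking this harvest
        # to the current study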
        hs, created = HarvestSettings.objects.get_or_create(
            study=self.study, harvest=harvest)
        if created:
            self.message = _(u"\"%(harvest)s\" has been added to your "
                             u"settings; don't forget to fill yields for "
                             u"this harvest.") % {'harvest': harvest.name}
        return hs


class HarvestPlotsImportForm(forms.Form):
    csv_file = forms.FileField(label=_(u"Plot list file (CSV)"))

    def save(self, study):
        csv_file = self.cleaned_data['csv_file']
        importer = models.HarvestPlotImporter(study, skip_first_line=True)
        # some software (at least Gnumeric) converts CSV files to UTF-8
        # regardless of the source encoding
        encodings = [settings.ENCODING, 'utf-8']
        for encoding in encodings:
            try:
                importer.importation(unicode_csv_reader(
                    [line.decode(encoding)
                     for line in csv_file.readlines()]))
            except ImporterError, e:
                if e.type == ImporterError.HEADER \
                        and encoding != encodings[-1]:
                    csv_file.seek(0)
                    continue
                return 0, [[0, 0, e.msg]], []
            except UnicodeDecodeError, e:
                return 0, [[0, 0, Importer.ERRORS['header_check']]], []
            break
        return importer.number_imported, importer.errors, importer.messages
"""

import copy, csv, datetime, logging, sys
from tempfile import NamedTemporaryFile

from django.contrib.auth.models import User
from django.db import DatabaseError
from django.template.defaultfilters import slugify
from django.utils.translation import ugettext_lazy as _

from ishtar_common.unicode_csv import UnicodeWriter


class ImportFormater(object):
    def __init__(self, field_name, formater=None, required=True,
                 through=None, through_key=None, through_dict=None,
                 through_unicity_keys=None, duplicate_field=None,
                 regexp=None, regexp_formater_args=[],
                 reverse_for_test=None, comment=""):
        self.field_name = field_name
        self.formater = formater
        self.required = required
        self.through = through
        self.through_key = through_key
        self.through_dict = through_dict
        self.through_unicity_keys = through_unicity_keys
        self.duplicate_field = duplicate_field
        self.regexp = regexp
        self.regexp_formater_args = regexp_formater_args
        self.reverse_for_test = reverse_for_test
        self.comment = comment

    def __unicode__(self):
        return self.field_name

    def report_succes(self, *args):
        return

    def report_error(self, *args):
        return

    def init(self, vals):
        try:
            lst = iter(self.formater)
        except TypeError:
            lst = [self.formater]
        for formater in lst:
            formater.check(vals)


class ImporterError(Exception):
    STANDARD = 'S'
    HEADER = 'H'

    def __init__(self, message, type='S'):
        self.msg = message
        self.type = type

    def __str__(self):
        return self.msg


class Formater(object):
    def format(self, value):
        return value

    def check(self, values):
        return


class UnicodeFormater(Formater):
    def __init__(self, max_length, clean=False, re_filter=None):
        self.max_length = max_length
        self.clean = clean
        self.re_filter = re_filter

    def format(self, value):
        try:
            value = unicode(value.strip())
            if self.re_filter:
                m = self.re_filter.match(value)
                if m:
                    value = u"".join(m.groups())
            if self.clean:
                if value.startswith(","):
                    value = value[1:]
                if value.endswith(","):
                    value = value[:-1]
        except UnicodeDecodeError:
            return
        if len(value) > self.max_length:
            raise ValueError(
                _(u"\"%(value)s\" is too long. The max length is "
                  u"%(length)d characters.")
                % {'value': value, 'length': self.max_length})
        return value


class BooleanFormater(Formater):
    def format(self, value):
        value = value.strip().upper()
        if value in ('1', 'OUI', 'VRAI', 'YES', 'TRUE'):
            return True
        if value in ('', '0', 'NON', 'FAUX', 'NO', 'FALSE'):
            return False
        raise ValueError(_(u"\"%(value)s\" not equal to yes or no")
                         % {'value': value})


class FloatFormater(Formater):
    def format(self, value):
        value = value.strip().replace(',', '.')
        if not value:
            return
        try:
            return float(value)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a float")
                             % {'value': value})


class StrChoiceFormater(Formater):
    def __init__(self, choices, strict=False, equiv_dict={}, cli=False):
        self.choices = choices
        self.strict = strict
        self.equiv_dict = copy.deepcopy(equiv_dict)
        self.cli = cli
        self.missings = set()
        for key, value in self.choices:
            value = unicode(value)
            if not self.strict:
                value = slugify(value)
            if value not in self.equiv_dict:
                self.equiv_dict[value] = key

    def prepare(self, value):
        return unicode(value).strip()

    def check(self, values):
        msgstr = unicode(_(u"Choice for \"%s\" is not available. "
                           u"Which one is relevant?\n"))
        for idx, choice in enumerate(self.choices):
            msgstr += u"%d. %s\n" % (idx + 1, choice[1])
        msgstr += unicode(_(u"%d. None of the above")) % (idx + 2) + u"\n"
        for value in values:
            value = self.prepare(value)
            if value in self.equiv_dict:
                continue
            if not self.cli:
                self.missings.add(value)
                continue
            res = None
            while res not in range(1, len(self.choices) + 2):
                sys.stdout.write(msgstr % value)
                res = raw_input(">>> ")
                try:
                    res = int(res)
                except ValueError:
                    pass
            res -= 1
            if res < len(self.choices):
                self.equiv_dict[value] = self.choices[res]
            else:
                self.equiv_dict[value] = None

    def format(self, value):
        value = self.prepare(value)
        if not self.strict:
            value = slugify(value)
        if value in self.equiv_dict:
            return self.equiv_dict[value]


logger = logging.getLogger(__name__)


class Importer(object):
    LINE_FORMAT = []
    OBJECT_CLS = None
    UNICITY_KEYS = []
    DEFAULTS = {}
    ERRORS = {
        'header_check': _(
            u"The given file is not correct. Check the file format. "
            u"If you use a CSV file: check that the column separator "
            u"and the encoding are the same as the ones used by the "
            u"reference file."),
        'too_many_cols': _(u"Too many columns (%(user_col)d) when the "
                           u"maximum is %(ref_col)d"),
        'no_data': _(u"No data provided"),
        'value_required': _(u"Value is required"),
        'not_enough_cols': _(u"At least %d columns must be filled"),
        'regex_not_match': _(u"The regexp doesn't match."),
    }

    def __init__(self, skip_first_line=False, reference_header=None,
                 check_col_num=False, test=False, check_validity=True,
                 history_modifier=None):
        """
        * skip_first_line must be set to True if the provided data has a
          header.
        * a reference_header can be provided to perform a data compliance
          check. It can be useful to warn about bad parsing.
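        * check_col_num, when set, raises an error if the source provides
          more columns than expected by LINE_FORMAT.
        * history_modifier is the user recorded on created related items;
          it defaults to the first superuser.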
        * test doesn't write to the database.
        * check_validity rewrites a CSV file that can be compared with the
          source.
        """
        self.message = ''
        self.skip_first_line = skip_first_line
        self.reference_header = reference_header
        self.test = test
        self.errors = []  # list of (line, col, message)
        self.messages = []  # list of messages
        self.number_imported = 0
        self.check_col_num = check_col_num
        self.check_validity = check_validity
        self.line_format = copy.copy(self.LINE_FORMAT)
        self._initialized = False
        self._defaults = self.DEFAULTS.copy()
        self.history_modifier = history_modifier
        if not self.history_modifier:
            # get the first admin
            self.history_modifier = User.objects.filter(
                is_superuser=True).order_by('pk')[0]

    def initialize(self, table):
        # copy vals in columns
        vals = []
        for idx_line, line in enumerate(table):
            if self.skip_first_line and not idx_line:
                continue
            for idx_col, val in enumerate(line):
                if idx_col >= len(self.line_format):
                    break
                if idx_col >= len(vals):
                    vals.append([])
                vals[idx_col].append(val)
        for idx, formater in enumerate(self.line_format):
            formater.init(vals[idx])
        self._initialized = True

    def importation(self, table):
        if not self._initialized:
            self.initialize(table)
        if self.check_validity:
            with NamedTemporaryFile(delete=False) as validity_file:
                print(validity_file.name)
                validity_file_writer = UnicodeWriter(
                    validity_file, delimiter=',', quotechar='"',
                    quoting=csv.QUOTE_MINIMAL)
                self._importation(table, validity_file_writer)
        else:
            self._importation(table)

    def _importation(self, table, validity_file=None):
        table = list(table)
        if not table or not table[0]:
            raise ImporterError(self.ERRORS['no_data'],
                                ImporterError.HEADER)
        if self.check_col_num and len(table[0]) > len(self.line_format):
            raise ImporterError(self.ERRORS['too_many_cols'] % {
                'user_col': len(table[0]),
                'ref_col': len(self.line_format)})
        self.errors = []
        self.messages = []
        self.number_imported = 0
        # index of the last required column
        for idx_last_col, formater in enumerate(
                reversed(self.line_format)):
            if formater.required:
                break
        else:
            idx_last_col += 1
        # minimum number of columns to be filled
        min_col_number = len(self.line_format) - idx_last_col
        # check the conformity with the reference header
        if self.reference_header \
                and self.skip_first_line \
                and self.reference_header != table[0]:
            raise ImporterError(self.ERRORS['header_check'],
                                type=ImporterError.HEADER)
        now = datetime.datetime.now()
        for idx_line, line in enumerate(table):
            if self.skip_first_line and not idx_line:
                if validity_file:
                    validity_file.writerow(line)
                continue
            if not line:
                if validity_file:
                    validity_file.writerow([])
                continue
            throughs = []  # list of (formater, value)
            data = {}
            n = datetime.datetime.now()
            logger.debug('%s - Processing line %d'
                         % (unicode(n - now), idx_line))
            now = n
            n2 = n
            c_errors = False
            c_row = []
            for idx_col, val in enumerate(line):
                if idx_col >= len(self.line_format):
                    break
                formater = self.line_format[idx_col]
                if not formater.field_name:
                    if validity_file:
                        c_row.append(val)
                    continue
                if formater.regexp:
                    # multiline regexp is a mess...
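                    # the cell's newlines are swapped for a sentinel string
                    # so a single-line pattern can match; they are restored
                    # in the extracted groups afterwards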
                    val = val.replace('\n', '######???#####')
                    match = formater.regexp.match(val)
                    if not match:
                        if formater.required:
                            self.errors.append(
                                (idx_line + 1, idx_col + 1,
                                 self.ERRORS['value_required']))
                        elif not val.strip():
                            c_row.append("")
                            continue
                        c_errors = True
                        val = val.replace('######???#####', '\n')
                        self.errors.append(
                            (idx_line + 1, idx_col + 1,
                             unicode(self.ERRORS['regex_not_match']) + val))
                        c_row.append("")
                        continue
                    val_group = [v.replace('######???#####', '\n')
                                 for v in match.groups()]
                else:
                    val_group = [val]
                c_values = []
                for idx_v, v in enumerate(val_group):
                    self.message = ''
                    func = formater.formater
                    if type(func) in (list, tuple):
                        func = func[idx_v]
                    if not callable(func) and type(func) in (unicode, str):
                        func = getattr(self, func)
                    value = None
                    try:
                        if formater.regexp_formater_args:
                            args = []
                            for idx in formater.regexp_formater_args[idx_v]:
                                args.append(val_group[idx])
                            value = func.format(*args)
                        else:
                            value = func.format(v)
                    except ValueError, e:
                        c_errors = True
                        self.errors.append((idx_line + 1, idx_col + 1,
                                            e.message))
                        c_values.append(None)
                        continue
                    if self.message:
                        self.messages.append(self.message)
                    c_values.append(value)
                    if value is None:
                        if formater.required:
                            c_errors = True
                            self.errors.append(
                                (idx_line + 1, idx_col + 1,
                                 self.ERRORS['value_required']))
                        continue
                    field_name = formater.field_name
                    if type(field_name) in (list, tuple):
                        field_name = field_name[idx_v]
                    field_names = [field_name]
                    if formater.duplicate_field:
                        duplicate_field = formater.duplicate_field
                        if type(duplicate_field) in (list, tuple):
                            duplicate_field = duplicate_field[idx_v]
                        field_names += [duplicate_field]
                    if not formater.through:
                        for field_name in field_names:
                            keys = field_name.split('__')
                            current_data = data
                            for idx, key in enumerate(keys):
                                if idx == (len(keys) - 1):  # last
                                    current_data[key] = value
                                elif key not in current_data:
                                    current_data[key] = {}
                                current_data = current_data[key]
                    else:
                        throughs.append((formater, value))
                if formater.reverse_for_test:
                    c_row.append(formater.reverse_for_test(**c_values))
                else:
                    c_row.append(unicode(c_values))
            if validity_file:
                validity_file.writerow(c_row)
            if not c_errors and (idx_col + 1) < min_col_number:
                c_errors = True
                self.errors.append((idx_line + 1, idx_col + 1,
                                    self.ERRORS['not_enough_cols']
                                    % min_col_number))
            if c_errors:
                continue
            n = datetime.datetime.now()
            logger.debug('* %s - Cols read' % unicode(n - n2))
            n2 = n
            if self.test:
                continue
            # manage unicity of items (mainly for updates)
            self.number_imported += 1
            if self.UNICITY_KEYS:
                data['defaults'] = {}
                for k in data.keys():
                    if k not in self.UNICITY_KEYS and k != 'defaults':
                        data['defaults'][k] = data.pop(k)
            obj, created = self.get_object(self.OBJECT_CLS, data)
            if not created and 'defaults' in data:
                for k in data['defaults']:
                    setattr(obj, k, data['defaults'][k])
                obj.save()
            n = datetime.datetime.now()
            logger.debug('* %s - Item saved' % unicode(n - n2))
            n2 = n
            for formater, value in throughs:
                n = datetime.datetime.now()
                logger.debug('* %s - Processing formater %s'
                             % (unicode(n - n2), formater.field_name))
                n2 = n
                data = {}
                if formater.through_dict:
                    data = formater.through_dict.copy()
                if formater.through_key:
                    data[formater.through_key] = obj
                data[formater.field_name] = value
                through_cls = formater.through
                if formater.through_unicity_keys:
                    data['defaults'] = {}
                    for k in data.keys():
                        if k not in formater.through_unicity_keys \
                                and k != 'defaults':
                            data['defaults'][k] = data.pop(k)
                t_obj, created = through_cls.objects.get_or_create(**data)
                if not created and 'defaults' in data:
                    for k in data['defaults']:
                        setattr(t_obj, k, data['defaults'][k])
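                    # persist fields updated on a pre-existing through object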
                    t_obj.save()

    def get_object(self, cls, data, path=[]):
        m2ms = []
        if data and type(data) == dict:
            for attribute in data.keys():
                c_path = path[:]
                if not data[attribute]:
                    continue
                field_object, model, direct, m2m = \
                    cls._meta.get_field_by_name(attribute)
                if field_object.rel and type(data[attribute]) == dict:
                    c_path.append(attribute)
                    # put history_modifier for every created item
                    data[attribute]['history_modifier'] = \
                        self.history_modifier
                    data[attribute], created = self.get_object(
                        field_object.rel.to, data[attribute], c_path)
                if m2m:
                    m2ms.append((attribute, data.pop(attribute)))
            path = tuple(path)
            if path in self._defaults:
                for k in self._defaults[path]:
                    if k not in data or not data[k]:
                        data[k] = self._defaults[path][k]
            obj, created = cls.objects.get_or_create(**data)
            for attr, value in m2ms:
                getattr(obj, attr).add(value)
            return obj, created
        return data

    def get_csv_errors(self):
        for line, col, error in self.errors:
            print '"%d","%d","%s"' % (line, col, unicode(error))

    @classmethod
    def choices_check(cls, choices):
        def function(value):
            choices_dct = dict(choices)
            value = value.strip()
            if not value:
                return
            if value not in choices_dct.values():
                raise ValueError(_(u"\"%(value)s\" not in %(values)s") % {
                    'value': value,
                    'values': u", ".join(
                        [val for val in choices_dct.values()])})
            return value
        return function
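

# A minimal usage sketch of the classes above, kept as a comment because it
# needs a configured Django project to run. The "Author" model and the field
# names below are hypothetical and only illustrative; real importers, such as
# the one in the module docstring, follow the same pattern.
#
# class AuthorImporter(Importer):
#     OBJECT_CLS = Author  # hypothetical model with 'name' and 'birth_year'
#     UNICITY_KEYS = []    # no unicity keys: lines are not merged by key
#     LINE_FORMAT = [
#         ImportFormater('name', UnicodeFormater(100)),
#         ImportFormater('birth_year', FloatFormater(), required=False),
#     ]
#
# importer = AuthorImporter(skip_first_line=True)
# importer.importation([[u"name", u"birth year"],
#                       [u"Jane Doe", u"1980"]])
# print importer.number_imported, importer.errors, importer.messages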