#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2013-2014 Étienne Loks

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# See the file COPYING for details.

"""Table (CSV-like) importer: value formaters and the generic ``Importer``.

The module targets Python 2: it relies on the ``unicode`` and ``raw_input``
builtins.
"""

import copy
import csv
import datetime
import io
import logging
import sys
from tempfile import NamedTemporaryFile

from django.contrib.auth.models import User
from django.db import DatabaseError, IntegrityError
from django.template.defaultfilters import slugify
from django.utils.translation import ugettext_lazy as _

from ishtar_common.unicode_csv import UnicodeWriter

logger = logging.getLogger(__name__)

# placeholder substituted to line breaks while a column regexp is applied
NEW_LINE_BREAK = '#####@@@#####'


class ImportFormater(object):
    """Describe how one column of the imported table is handled.

    The attributes are plain copies of the constructor arguments; they are
    read by ``Importer._row_processing`` and ``Importer._line_processing``.
    """

    def __init__(self, field_name, formater=None, required=True, through=None,
                 through_key=None, through_dict=None,
                 through_unicity_keys=None, duplicate_field=None, regexp=None,
                 regexp_formater_args=None, reverse_for_test=None,
                 force_value=None, post_processing=False, concat=False,
                 comment=""):
        self.field_name = field_name
        self.formater = formater
        self.required = required
        self.through = through
        self.through_key = through_key
        self.through_dict = through_dict
        self.through_unicity_keys = through_unicity_keys
        self.duplicate_field = duplicate_field
        self.regexp = regexp
        # a fresh list per instance: never share a mutable default
        if regexp_formater_args is None:
            regexp_formater_args = []
        self.regexp_formater_args = regexp_formater_args
        self.reverse_for_test = reverse_for_test
        # write this value even if a value exists
        self.force_value = force_value
        # post process after import
        self.post_processing = post_processing
        # concatenate with existing value
        self.concat = concat
        self.comment = comment

    def __unicode__(self):
        return self.field_name

    def report_succes(self, *args):
        """Hook called on success; does nothing by default."""
        return

    def report_error(self, *args):
        """Hook called on error; does nothing by default."""
        return

    def init(self, vals):
        """Let every attached formater inspect the values of the column.

        ``self.formater`` may be a single formater, a sequence of formaters
        or a string (resolved later as an attribute of the importer).
        Only objects exposing a ``check`` method are called.
        """
        formaters = self.formater
        if isinstance(formaters, (str, unicode)):
            # a string is iterable: do not iterate over its characters
            formaters = [formaters]
        else:
            try:
                formaters = iter(formaters)
            except TypeError:
                formaters = [formaters]
        for formater in formaters:
            check = getattr(formater, 'check', None)
            if check:
                check(vals)

    def post_process(self, obj, context, value, owner=None):
        """To be overloaded when ``post_processing`` is set.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()


class ImporterError(Exception):
    """Error raised by the importer.

    ``type`` is ``STANDARD`` or ``HEADER``.
    """
    STANDARD = 'S'
    HEADER = 'H'

    def __init__(self, message, type='S'):
        # keep Exception.args populated
        super(ImporterError, self).__init__(message)
        self.msg = message
        self.type = type

    def __str__(self):
        return self.msg


class Formater(object):
    """Base formater: returns the value untouched and checks nothing."""

    def format(self, value):
        return value

    def check(self, values):
        return


class UnicodeFormater(Formater):
    """Convert a value to a stripped unicode string of bounded length."""

    def __init__(self, max_length, clean=False, re_filter=None):
        self.max_length = max_length
        self.clean = clean
        self.re_filter = re_filter

    def format(self, value):
        """Return the cleaned string, or None when it cannot be decoded.

        Raises:
            ValueError: the result is longer than ``max_length``.
        """
        try:
            value = unicode(value.strip())
            if self.re_filter:
                m = self.re_filter.match(value)
                if m:
                    # keep only what the groups of the filter captured
                    value = u"".join(m.groups())
            if self.clean:
                if value.startswith(","):
                    value = value[1:]
                if value.endswith(","):
                    value = value[:-1]
                value = value.replace(", , ", ", ")
        except UnicodeDecodeError:
            return None
        if len(value) > self.max_length:
            raise ValueError(_(u"\"%(value)s\" is too long. "
                               u"The max length is %(length)d characters."
                               ) % {'value': value,
                                    'length': self.max_length})
        return value


class BooleanFormater(Formater):
    """Convert French/English yes-no keywords to a boolean."""

    def format(self, value):
        """Return True or False (an empty string is False).

        Raises:
            ValueError: the value is not a known keyword.
        """
        value = value.strip().upper()
        if value in ('1', 'OUI', 'VRAI', 'YES', 'TRUE'):
            return True
        if value in ('', '0', 'NON', 'FAUX', 'NO', 'FALSE'):
            return False
        raise ValueError(_(u"\"%(value)s\" not equal to yes or no") % {
            'value': value})


class FloatFormater(Formater):
    """Convert a value to a float; a comma is accepted as decimal mark."""

    def format(self, value):
        """Return the float, or None for an empty value.

        Raises:
            ValueError: the value is not a float.
        """
        value = value.strip().replace(',', '.')
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a float") % {
                'value': value})


def _format_year(value, max_year):
    """Return ``value`` as an int in ``[1, max_year]``; None when empty.

    Raises:
        ValueError: the value is not an integer or is out of range.
    """
    value = value.strip()
    if not value:
        return None
    try:
        value = int(value)
        is_valid = 0 < value <= max_year
    except ValueError:
        is_valid = False
    # explicit test rather than ``assert``, which is stripped by ``-O``
    if not is_valid:
        raise ValueError(_(u"\"%(value)s\" is not a valid date") % {
            'value': value})
    return value


class YearFormater(Formater):
    """Accept a year between 1 and the current year + 29 (inclusive)."""

    def format(self, value):
        return _format_year(value, datetime.date.today().year + 29)


class YearNoFuturFormater(Formater):
    """Accept a year between 1 and the current year (inclusive)."""

    def format(self, value):
        return _format_year(value, datetime.date.today().year)


class IntegerFormater(Formater):
    """Convert a value to an integer."""

    def format(self, value):
        """Return the integer, or None for an empty value.

        Raises:
            ValueError: the value is not an integer.
        """
        value = value.strip()
        if not value:
            return None
        try:
            return int(value)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not an integer") % {
                'value': value})


class StrChoiceFormater(Formater):
    """Map a textual value to a choice key (or to a model instance).

    ``choices`` is an iterable of (key, label). Unless ``strict`` is set,
    labels and values are compared through their slug. When ``model`` is
    given, non-empty keys are replaced by ``model.objects.get(pk=key)``.
    Values that match nothing are either asked on the command line
    (``cli=True``) or collected in ``self.missings``.
    """

    def __init__(self, choices, strict=False, equiv_dict=None, model=None,
                 cli=False):
        self.choices = list(choices)
        self.strict = strict
        self.equiv_dict = copy.deepcopy(equiv_dict) if equiv_dict else {}
        self.cli = cli
        self.model = model
        self.missings = set()
        for key, value in self.choices:
            value = self._lookup_key(unicode(value))
            if value not in self.equiv_dict:
                self.equiv_dict[value] = self._get_item(key)

    def _lookup_key(self, value):
        """Return the key used in ``equiv_dict`` for a prepared value."""
        if self.strict:
            return value
        return slugify(value)

    def _get_item(self, key):
        """Return the model instance behind a choice key when relevant."""
        if self.model and key:
            return self.model.objects.get(pk=key)
        return key

    def prepare(self, value):
        return unicode(value).strip()

    def check(self, values):
        """Register (or ask for) an equivalent for each unknown value."""
        header = unicode(_(u"Choice for \"%s\" is not available. "
                           u"Which one is relevant?\n"))
        # the menu is kept apart from the header so that a "%" inside a
        # label cannot break the interpolation of the header
        menu = u""
        for idx, choice in enumerate(self.choices):
            menu += u"%d. %s\n" % (idx + 1, choice[1])
        nb_choices = len(self.choices)
        menu += unicode(_(u"%d. None of the above")) % (nb_choices + 1) \
            + u"\n"
        for value in values:
            value = self.prepare(value)
            key = self._lookup_key(value)
            # the raw value is also tested: ``equiv_dict`` may have been
            # provided with non-slugified keys
            if key in self.equiv_dict or value in self.equiv_dict:
                continue
            if not self.cli:
                self.missings.add(value)
                continue
            res = None
            while res not in range(1, nb_choices + 2):
                sys.stdout.write(header % value + menu)
                res = raw_input(">>> ")
                try:
                    res = int(res)
                except ValueError:
                    pass
            res -= 1
            if res < nb_choices:
                # stored under the key that format() looks up
                self.equiv_dict[key] = self._get_item(self.choices[res][0])
            else:
                self.equiv_dict[key] = None

    def format(self, value):
        """Return the equivalent of ``value`` or None when unknown."""
        value = self.prepare(value)
        key = self._lookup_key(value)
        if key in self.equiv_dict:
            return self.equiv_dict[key]
        if value in self.equiv_dict:
            return self.equiv_dict[value]
        return None


class DateFormater(Formater):
    """Parse a date with ``datetime.strptime`` and ``date_format``."""

    def __init__(self, date_format="%d/%m/%Y"):
        self.date_format = date_format

    def format(self, value):
        """Return a ``datetime.date``, or None for an empty value.

        Raises:
            ValueError: the value does not match ``date_format``.
        """
        value = value.strip()
        if not value:
            # same convention as the other formaters: empty means no value
            return None
        try:
            return datetime.datetime.strptime(value, self.date_format).date()
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a valid date") % {
                'value': value})


class Importer(object):
    """Import the lines of a table into ``OBJECT_CLS`` objects.

    Subclasses configure the import through the class attributes below.
    """
    # one ImportFormater per column; a falsy item means the column is
    # not interpreted
    LINE_FORMAT = []
    # class handed to get_object() for each line
    OBJECT_CLS = None
    # when set, key of the data dict receiving the raw line as CSV
    IMPORTED_LINE_FIELD = None
    # keys kept at the top level of the data dict; when set, the other
    # keys are moved under data['defaults']
    UNICITY_KEYS = []
    # default values, indexed by a tuple of attribute names (the path from
    # OBJECT_CLS to the related object)
    DEFAULTS = {}
    ERRORS = {
        'header_check': _(
            u"The given file is not correct. Check the file "
            u"format. If you use a CSV file: check that column separator "
            u"and encoding are similar to the ones used by the reference "
            u"file."),
        'too_many_cols': _(u"Too many cols (%(user_col)d) when "
                           u"maximum is %(ref_col)d"),
        'no_data': _(u"No data provided"),
        'value_required': _(u"Value is required"),
        'not_enough_cols': _(u"At least %d columns must be filled"),
        'regex_not_match': _(u"The regexp doesn't match.")
    }

    def __init__(self, skip_first_line=False, reference_header=None,
                 check_col_num=False, test=False, check_validity=True,
                 history_modifier=None):
        """
         * skip_first_line must be set to True if the data provided has got
           an header.
         * a reference_header can be provided to perform a data compliance
           check. It can be useful to warn about bad parsing.
         * check_col_num: refuse a table with more columns than LINE_FORMAT
         * test doesn't write in the database
         * check_validity rewrite a CSV file to be compared
         * history_modifier: user attached to the imported items; when not
           provided the first superuser (ordered by pk) is used
        """
        self.message = ''
        self.skip_first_line = skip_first_line
        self.reference_header = reference_header
        self.test = test
        self.errors = []  # list of (line, col, message)
        self.messages = []  # list of (line, col, message)
        self.number_imported = 0
        self.check_col_num = check_col_num
        self.check_validity = check_validity
        self.line_format = copy.copy(self.LINE_FORMAT)
        self._initialized = False
        self._defaults = self.DEFAULTS.copy()
        self.history_modifier = history_modifier
        if not self.history_modifier:
            # get the first admin
            self.history_modifier = User.objects.filter(
                is_superuser=True).order_by('pk')[0]

    def initialize(self, table):
        """Give each column formater the values it is going to handle."""
        # copy vals in columns - one (possibly empty) list per formater so
        # that a table narrower than the line format is not an error
        vals = [[] for __ in self.line_format]
        for idx_line, line in enumerate(table):
            if self.skip_first_line and not idx_line:
                continue
            # zip() drops the columns in excess
            for column, val in zip(vals, line):
                column.append(val)
        for formater, column in zip(self.line_format, vals):
            if formater:
                formater.init(column)
        self._initialized = True

    def importation(self, table):
        """Import ``table`` (an iterable of lines, each a list of values).

        When ``check_validity`` is set, the processed rows are also written
        to a temporary file (not deleted on close) through UnicodeWriter.
        """
        # the table is read twice (initialization then import): freeze it
        # in case a one-shot iterator has been provided
        table = list(table)
        self.validity_file = None
        if not self._initialized:
            self.initialize(table)
        if not self.check_validity:
            self._importation(table)
            return
        with NamedTemporaryFile(delete=False) as validity_file:
            self.validity_file = UnicodeWriter(
                validity_file, delimiter=',', quotechar='"',
                quoting=csv.QUOTE_MINIMAL)
            self._importation(table)

    @classmethod
    def _field_name_to_data_dict(cls, field_name, value, data,
                                 force_value=False, concat=False):
        """Store ``value`` in ``data`` following ``field_name``.

        ``field_name`` is a name or a list/tuple of names. Each name is
        split on "__" and every part but the last one designates a nested
        dict (created when missing). For the last part:
         * concat: the value is appended to the existing one, separated by
           a line break
         * force_value: a non-empty value overwrites the existing one
         * otherwise the value is only written when nothing (or an empty
           value) is already there

        Returns ``data``.
        """
        field_names = field_name
        if not isinstance(field_names, (list, tuple)):
            field_names = [field_name]
        for field_name in field_names:
            keys = field_name.split('__')
            current_data = data
            # go down to the dict holding the last key
            for key in keys[:-1]:
                if key not in current_data:
                    current_data[key] = {}
                current_data = current_data[key]
            key = keys[-1]
            if concat:
                existing = current_data.get(key)
                if not value:
                    # nothing to add: only make sure the key exists
                    current_data[key] = existing or u""
                elif existing:
                    current_data[key] = existing + u"\n" + value
                else:
                    current_data[key] = value
            elif force_value and value:
                current_data[key] = value
            elif key not in current_data or not current_data[key]:
                current_data[key] = value
        return data

    @staticmethod
    def _split_defaults(data, unicity_keys):
        """Move the items of ``data`` not in ``unicity_keys`` to 'defaults'.
        """
        defaults = {}
        # iterate over a snapshot: the dict is modified in the loop
        for key in list(data):
            if key not in unicity_keys and key != 'defaults':
                defaults[key] = data.pop(key)
        data['defaults'] = defaults

    def _importation(self, table):
        """Check the table as a whole then process it line by line.

        Raises:
            ImporterError: no data, too many columns (when check_col_num
                is set) or header different from the reference header.
        """
        table = list(table)
        if not table or not table[0]:
            raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER)
        if self.check_col_num and len(table[0]) > len(self.line_format):
            raise ImporterError(self.ERRORS['too_many_cols'] % {
                'user_col': len(table[0]),
                'ref_col': len(self.line_format)})
        self.errors = []
        self.messages = []
        self.number_imported = 0
        # min col number to be filled: position of the last required
        # column (0 when no column is required)
        self.min_col_number = 0
        for idx_col, formater in enumerate(self.line_format):
            if formater and formater.required:
                self.min_col_number = idx_col + 1
        # check the conformity with the reference header
        if self.reference_header and self.skip_first_line and \
           self.reference_header != table[0]:
            raise ImporterError(self.ERRORS['header_check'],
                                type=ImporterError.HEADER)
        self.now = datetime.datetime.now()
        for idx_line, line in enumerate(table):
            self._line_processing(idx_line, line)

    def _line_processing(self, idx_line, line):
        """Convert one line and, unless in test mode, save the result."""
        if self.skip_first_line and not idx_line:
            if self.validity_file:
                self.validity_file.writerow(line)
            return
        if not line:
            if self.validity_file:
                self.validity_file.writerow([])
            return
        self._throughs = []  # list of (formater, value)
        self._post_processing = []  # list of (formater, value)
        data = {}

        # keep in database the raw line for testing purpose
        if self.IMPORTED_LINE_FIELD:
            # NOTE(review): with Python 2, io.StringIO only accepts unicode
            # while csv.writer writes byte strings - to be confirmed
            output = io.StringIO()
            writer = csv.writer(output)
            writer.writerow(line)
            data[self.IMPORTED_LINE_FIELD] = output.getvalue()

        n = datetime.datetime.now()
        logger.debug('%s - Processing line %d', n - self.now, idx_line)
        self.now = n
        n2 = n
        self.c_errors = False
        c_row = []
        for idx_col, val in enumerate(line):
            try:
                self._row_processing(c_row, idx_col, idx_line, val, data)
            except Exception:
                # best effort: an unexpected error on a column does not
                # stop the import, but it must not stay silent
                logger.exception('Unexpected error on line %d, column %d',
                                 idx_line + 1, idx_col + 1)
        if self.validity_file:
            self.validity_file.writerow(c_row)
        if not self.c_errors and len(line) < self.min_col_number:
            self.c_errors = True
            self.errors.append((
                idx_line + 1, len(line),
                self.ERRORS['not_enough_cols'] % self.min_col_number))
        if self.c_errors:
            return
        n = datetime.datetime.now()
        logger.debug('* %s - Cols read', n - n2)
        n2 = n
        if self.test:
            return

        # manage unicity of items (mainly for updates)
        self.number_imported += 1
        if self.UNICITY_KEYS:
            self._split_defaults(data, self.UNICITY_KEYS)
        data['history_modifier'] = self.history_modifier
        obj, created = self.get_object(self.OBJECT_CLS, data)
        if not created and 'defaults' in data:
            for k in data['defaults']:
                setattr(obj, k, data['defaults'][k])
            obj.save()

        n = datetime.datetime.now()
        logger.debug('* %s - Item saved', n - n2)
        n2 = n
        for formater, value in self._throughs:
            n = datetime.datetime.now()
            logger.debug('* %s - Processing formater %s', n - n2,
                         formater.field_name)
            n2 = n
            data = {}
            if formater.through_dict:
                data = formater.through_dict.copy()
            if formater.through_key:
                data[formater.through_key] = obj
            data[formater.field_name] = value
            through_cls = formater.through
            if formater.through_unicity_keys:
                self._split_defaults(data, formater.through_unicity_keys)
            t_obj, created = through_cls.objects.get_or_create(**data)
            if not created and 'defaults' in data:
                for k in data['defaults']:
                    setattr(t_obj, k, data['defaults'][k])
                t_obj.save()

        # NOTE(review): ``data`` is the dict of the last "through" item when
        # there is one, the dict of the line otherwise - confirm which
        # context post_process() implementations expect
        for formater, val in self._post_processing:
            formater.post_process(obj, data, val, owner=self.history_modifier)

    def _row_processing(self, c_row, idx_col, idx_line, val, data):
        """Convert the value of one column and store it in ``data``.

        Errors are appended to ``self.errors`` as (line, col, message) and
        ``self.c_errors`` is set when the line must not be imported.
        ``c_row`` receives the text written in the validity file.
        """
        if idx_col >= len(self.line_format):
            return
        formater = self.line_format[idx_col]
        if formater and formater.post_processing:
            self._post_processing.append((formater, val))
        if not formater or not formater.field_name:
            if self.validity_file:
                c_row.append(val)
            return

        # regex management
        if formater.regexp:
            # multiline regexp is a mess...
            val = val.replace('\n', NEW_LINE_BREAK)
            match = formater.regexp.match(val)
            if not match:
                if formater.required:
                    self.errors.append((idx_line + 1, idx_col + 1,
                                        self.ERRORS['value_required']))
                elif not val.strip():
                    # empty and not required: nothing to do
                    c_row.append("")
                    return
                self.c_errors = True
                val = val.replace(NEW_LINE_BREAK, '\n')
                self.errors.append((
                    idx_line + 1, idx_col + 1,
                    unicode(self.ERRORS['regex_not_match']) + val))
                c_row.append("")
                return
            val_group = [v.replace(NEW_LINE_BREAK, '\n')
                         for v in match.groups()]
        else:
            val_group = [val]

        c_values = []
        for idx_v, v in enumerate(val_group):
            self.message = ''
            func = formater.formater
            if isinstance(func, (list, tuple)):
                func = func[idx_v]
            if not callable(func) and isinstance(func, (unicode, str)):
                # the formater is given by the name of an importer attribute
                func = getattr(self, func)
            value = None
            try:
                if formater.regexp_formater_args:
                    args = [val_group[idx]
                            for idx in formater.regexp_formater_args[idx_v]]
                    value = func.format(*args)
                else:
                    value = func.format(v)
            except ValueError as e:
                if formater.required:
                    self.c_errors = True
                message = e.args[0] if e.args else unicode(e)
                self.errors.append((idx_line + 1, idx_col + 1, message))
                c_values.append(None)
                return
            if self.message:
                self.messages.append(self.message)
            c_values.append(value)
            if value is None:
                if formater.required:
                    self.c_errors = True
                    self.errors.append((idx_line + 1, idx_col + 1,
                                        self.ERRORS['value_required']))
                return
            field_name = formater.field_name
            if isinstance(field_name, (list, tuple)):
                field_name = field_name[idx_v]
            field_names = [field_name]
            if formater.duplicate_field:
                duplicate_field = formater.duplicate_field
                if isinstance(duplicate_field, (list, tuple)):
                    duplicate_field = duplicate_field[idx_v]
                field_names.append(duplicate_field)
            if formater.through:
                # saved once the main object is available
                self._throughs.append((formater, value))
            else:
                for field_name in field_names:
                    self._field_name_to_data_dict(
                        field_name, value, data, formater.force_value,
                        concat=formater.concat)
        if formater.reverse_for_test:
            # c_values is a list: positional unpacking
            c_row.append(formater.reverse_for_test(*c_values))
        else:
            c_row.append(unicode(c_values))

    def get_object(self, cls, data, path=None):
        """Get or create the ``cls`` object described by ``data``.

        Nested dicts attached to relation fields are first replaced by the
        object they describe (recursive call), many-to-many values are
        added once the object is available and the DEFAULTS registered for
        ``path`` complete the missing values.

        Returns:
            (object, created) when ``data`` is a non-empty dict, ``data``
            itself otherwise.

        Raises:
            ImporterError: the database raised an IntegrityError.
        """
        if not data or not isinstance(data, dict):
            return data
        path = [] if path is None else list(path)
        m2ms = []
        # NOTE(review): a non-empty 'defaults' entry (added by
        # _line_processing when UNICITY_KEYS is set) is looked up like any
        # other field name - confirm this is expected
        for attribute in list(data):
            c_path = path[:]
            if not data[attribute]:
                continue
            field_object, model, direct, m2m = \
                cls._meta.get_field_by_name(attribute)
            if hasattr(field_object, 'rel') and field_object.rel and \
               isinstance(data[attribute], dict):
                c_path.append(attribute)
                # put history_modifier for every created item
                data[attribute]['history_modifier'] = self.history_modifier
                data[attribute], created = self.get_object(
                    field_object.rel.to, data[attribute], c_path)
            if m2m:
                m2ms.append((attribute, data.pop(attribute)))
        path = tuple(path)
        if path in self._defaults:
            for k in self._defaults[path]:
                if k not in data or not data[k]:
                    data[k] = self._defaults[path][k]

        # filter default values
        create_dict = copy.deepcopy(data)
        for k in list(create_dict):
            if isinstance(create_dict[k], dict):
                create_dict.pop(k)
        try:
            obj, created = cls.objects.get_or_create(**create_dict)
            for attr, value in m2ms:
                getattr(obj, attr).add(value)
        except IntegrityError:
            raise ImporterError("Erreur d'import %s, contexte : %s"
                                % (unicode(cls), unicode(data)))
        return obj, created

    def get_csv_errors(self):
        """Print the errors, one CSV line (line, column, message) each."""
        for line, col, error in self.errors:
            print('"%d","%d","%s"' % (line, col, unicode(error)))

    @classmethod
    def choices_check(cls, choices):
        """Return a function checking that a value is a label of ``choices``.

        The function returns the stripped value (None when empty) and
        raises ValueError when the value is not a known label.
        """
        # built once: ``choices`` may be a one-shot iterable
        labels = list(dict(choices).values())

        def function(value):
            value = value.strip()
            if not value:
                return None
            if value not in labels:
                raise ValueError(
                    _(u"\"%(value)s\" not in %(values)s") % {
                        'value': value,
                        'values': u", ".join(labels)
                    })
            return value
        return function