#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2013-2015 Étienne Loks

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# See the file COPYING for details.

import copy
import csv
import datetime
import io
import logging
import re
import sys
from tempfile import NamedTemporaryFile

from django.contrib.auth.models import User
from django.db import DatabaseError, IntegrityError
from django.template.defaultfilters import slugify
from django.utils.translation import ugettext_lazy as _

from ishtar_common.unicode_csv import UnicodeWriter

logger = logging.getLogger(__name__)

# Placeholder substituted for "\n" while a multiline cell is matched against
# a column regexp (see Importer._row_processing).
NEW_LINE_BREAK = '#####@@@#####'

# Keeps the text before a "CEDEX" mention. The pattern is unicode so that the
# accented "Cédex" variants can match the unicode values produced by
# UnicodeFormater (a byte-string pattern would only see the raw UTF-8 bytes).
RE_FILTER_CEDEX = re.compile(
    u"(.*) *(?: *CEDEX|cedex|Cedex|Cédex|cédex *\\d*)")


class ImportFormater(object):
    """Describe how one column of the imported table is handled.

    :param field_name: destination field (``__`` separates related fields);
        may be a list/tuple indexed by regexp group.
    :param formater: Formater instance, the name of an Importer method, or a
        list/tuple of them (one per regexp group).
    :param required: an empty/invalid value is reported as an error.
    :param through: model used to store the value through an intermediate
        table (see Importer._line_processing).
    :param duplicate_fields: extra destination fields receiving the value.
    :param regexp: compiled regexp whose groups split the raw cell.
    :param regexp_formater_args: per group, indexes of the groups passed as
        arguments to the formater.
    :param force_value: write this value even if a value already exists.
    :param post_processing: call ``post_process`` once the item is saved.
    :param concat: append to the existing value (newline separated).
    """

    def __init__(self, field_name, formater=None, required=True, through=None,
                 through_key=None, through_dict=None,
                 through_unicity_keys=None, duplicate_fields=None,
                 regexp=None, regexp_formater_args=None,
                 reverse_for_test=None, force_value=None,
                 post_processing=False, concat=False, comment=""):
        self.field_name = field_name
        self.formater = formater
        self.required = required
        self.through = through
        self.through_key = through_key
        self.through_dict = through_dict
        self.through_unicity_keys = through_unicity_keys
        # fresh lists: never share one mutable default between instances
        self.duplicate_fields = \
            duplicate_fields if duplicate_fields is not None else []
        self.regexp = regexp
        self.regexp_formater_args = \
            regexp_formater_args if regexp_formater_args is not None else []
        self.reverse_for_test = reverse_for_test
        # write this value even if a value exists
        self.force_value = force_value
        # post process after import
        self.post_processing = post_processing
        # concatenate with existing value
        self.concat = concat
        self.comment = comment

    def __unicode__(self):
        return self.field_name

    def report_succes(self, *args):
        return

    def report_error(self, *args):
        return

    def init(self, vals, output=None):
        """Let every formater of this column check the column values."""
        formaters = self.formater
        # only list/tuple means "one formater per group" (as in
        # Importer._row_processing); a method-name string must not be
        # iterated character by character
        if not isinstance(formaters, (list, tuple)):
            formaters = [formaters]
        for formater in formaters:
            # method-name formaters are resolved on the Importer later and
            # have no check step
            if formater and hasattr(formater, 'check'):
                formater.check(vals, output)

    def post_process(self, obj, context, value, owner=None):
        """Hook called after the item is saved; must be overridden."""
        raise NotImplementedError()


class ImporterError(Exception):
    """Import failure; ``type`` is STANDARD ('S') or HEADER ('H')."""
    STANDARD = 'S'
    HEADER = 'H'

    def __init__(self, message, type='S'):
        self.msg = message
        self.type = type

    def __str__(self):
        return self.msg


class Formater(object):
    """Base formater: returns the raw value and checks nothing."""

    def format(self, value):
        return value

    def check(self, values, output=None):
        return


class UnicodeFormater(Formater):
    """Strip and convert to unicode, optionally filtered and cleaned.

    Raises ValueError when the result is longer than ``max_length``.
    """

    def __init__(self, max_length, clean=False, re_filter=None,
                 notnull=False):
        self.max_length = max_length
        self.clean = clean
        self.re_filter = re_filter
        self.notnull = notnull

    def format(self, value):
        try:
            value = unicode(value.strip())
            if self.re_filter:
                m = self.re_filter.match(value)
                if m:
                    # unmatched optional groups are None: skip them
                    value = u"".join(group for group in m.groups() if group)
            if self.clean:
                if value.startswith(","):
                    value = value[1:]
                if value.endswith(","):
                    value = value[:-1]
                value = value.replace(", , ", ", ")
        except UnicodeDecodeError:
            return
        if len(value) > self.max_length:
            raise ValueError(_(u"\"%(value)s\" is too long. "
                               u"The max length is %(length)d characters."
                               ) % {'value': value,
                                    'length': self.max_length})
        if self.notnull and not value:
            return
        return value


class BooleanFormater(Formater):
    """Map French/English yes/no words to True/False."""

    def format(self, value):
        value = value.strip().upper()
        if value in ('1', 'OUI', 'VRAI', 'YES', 'TRUE'):
            return True
        if value in ('', '0', 'NON', 'FAUX', 'NO', 'FALSE'):
            return False
        raise ValueError(_(u"\"%(value)s\" not equal to yes or no") % {
            'value': value})


class FloatFormater(Formater):
    """Parse a float, accepting ',' as decimal separator; '' gives None."""

    def format(self, value):
        value = value.strip().replace(',', '.')
        if not value:
            return
        try:
            return float(value)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a float") % {
                'value': value})


class YearFormater(Formater):
    """Parse a year in ]0, current year + 30[; '' gives None."""

    def format(self, value):
        value = value.strip()
        if not value:
            return
        try:
            year = int(value)
        except ValueError:
            year = None
        # explicit test rather than assert: asserts vanish under "python -O"
        if year is None or not 0 < year < datetime.date.today().year + 30:
            raise ValueError(_(u"\"%(value)s\" is not a valid date") % {
                'value': value})
        return year


class YearNoFuturFormater(Formater):
    """Parse a year in ]0, current year]; '' gives None."""

    def format(self, value):
        value = value.strip()
        if not value:
            return
        try:
            year = int(value)
        except ValueError:
            year = None
        # the current year is not "in the future": accept it
        if year is None or not 0 < year <= datetime.date.today().year:
            raise ValueError(_(u"\"%(value)s\" is not a valid date") % {
                'value': value})
        return year


class IntegerFormater(Formater):
    """Parse an integer; '' gives None."""

    def format(self, value):
        value = value.strip()
        if not value:
            return
        try:
            return int(value)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not an integer") % {
                'value': value})


class StrChoiceFormater(Formater):
    """Map free text to one of ``choices`` ((key, label) pairs).

    Unless ``strict``, labels and values are compared slugified. With a
    ``model``, keys are replaced by the matching model instances.
    """

    def __init__(self, choices, strict=False, equiv_dict=None, model=None,
                 cli=False, many_split=''):
        self.choices = list(choices)
        self.strict = strict
        self.equiv_dict = copy.deepcopy(equiv_dict or {})
        self.cli = cli
        self.model = model
        self.create = False
        self.missings = set()
        self.many_split = many_split
        for key, value in self.choices:
            value = unicode(value)
            if not self.strict:
                value = slugify(value)
            if value not in self.equiv_dict:
                v = key
                if model and v:
                    v = model.objects.get(pk=v)
                self.equiv_dict[value] = v

    def prepare(self, value):
        return unicode(value).strip()

    def _normalize(self, value):
        """Key under which a prepared value is stored in ``equiv_dict``."""
        return value if self.strict else slugify(value)

    def _get_choices(self):
        """Return the interactive prompt (with a "%s" slot) and the last
        valid answer number."""
        msgstr = unicode(_(u"Choice for \"%s\" is not available. "
                           u"Which one is relevant?\n"))
        idx = -1
        for idx, choice in enumerate(self.choices):
            # escape "%": the prompt is later formatted with the value
            label = unicode(choice[1]).replace(u"%", u"%%")
            msgstr += u"%d. %s\n" % (idx + 1, label)
        idx += 2
        if self.create:
            msgstr += unicode(_(u"%d. None of the above - create new")) % idx \
                + u"\n"
            idx += 1
        msgstr += unicode(_(u"%d. None of the above - skip")) % idx + u"\n"
        return msgstr, idx

    def check(self, values, output=None):
        """Find values without equivalence: record them in ``missings``
        (output 'db') or ask for them on the command line (output 'cli')."""
        if not output or output == 'silent':
            return
        if self.many_split:
            splitter = re.compile(self.many_split)
            values = [part for value in values
                      for part in splitter.split(value)]
        for value in values:
            base_value = copy.copy(value)
            value = self.prepare(value)
            # normalize exactly like format() so that the answers given here
            # are found again when the values are formatted
            key = self._normalize(value)
            if key in self.equiv_dict:
                continue
            if output != 'cli':
                self.missings.add(value)
                continue
            msgstr, idx = self._get_choices()
            res = None
            while res not in range(1, idx + 1):
                sys.stdout.write(msgstr % value)
                res = raw_input(">>> ")
                try:
                    res = int(res)
                except ValueError:
                    pass
            res -= 1
            if res < len(self.choices):
                v = self.choices[res][0]
                if self.model and v:
                    v = self.model.objects.get(pk=v)
                self.equiv_dict[key] = v
                self.add_key(v, value)
            elif self.create and res == len(self.choices):
                self.equiv_dict[key] = self.new(base_value)
                self.choices.append((self.equiv_dict[key].pk,
                                     unicode(self.equiv_dict[key])))
            else:
                self.equiv_dict[key] = None

    def new(self, value):
        return

    def add_key(self, obj, value):
        return

    def format(self, value):
        value = self._normalize(self.prepare(value))
        if value in self.equiv_dict:
            return self.equiv_dict[value]


class TypeFormater(StrChoiceFormater):
    """StrChoiceFormater over every instance of a type ``model``, able to
    create new items (with ``defaults``) and register new keys."""

    def __init__(self, model, cli=False, defaults=None, many_split=False):
        self.create = True
        self.strict = False
        self.model = model
        self.defaults = defaults or {}
        self.many_split = many_split
        self.missings = set()
        self.equiv_dict, self.choices = {}, []
        for item in model.objects.all():
            self.choices.append((item.pk, unicode(item)))
            for key in item.get_keys():
                self.equiv_dict[key] = item

    def prepare(self, value):
        return slugify(unicode(value).strip())

    def add_key(self, obj, value):
        obj.add_key(slugify(value), force=True)

    def new(self, value):
        values = copy.copy(self.defaults)
        values['label'] = value
        values['txt_idx'] = slugify(value)
        if 'order' in self.model._meta.get_all_field_names():
            order = 1
            q = self.model.objects.values('order').order_by('-order')
            if q.count():
                order = q.all()[0]['order'] or 1
            # NOTE(review): the new item reuses the current highest order
            # value rather than max + 1 -- confirm this is intended.
            values['order'] = order
        return self.model.objects.create(**values)


class DateFormater(Formater):
    """Parse a date with ``date_format``; '' gives None."""

    def __init__(self, date_format="%d/%m/%Y"):
        self.date_format = date_format

    def format(self, value):
        value = value.strip()
        if not value:
            return
        try:
            return datetime.datetime.strptime(value, self.date_format).date()
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a valid date") % {
                'value': value})


class StrToBoolean(Formater):
    """Map free text to True/False/None through the ``choices`` dict."""

    def __init__(self, choices=None, cli=False, strict=False):
        self.dct = copy.copy(choices or {})
        self.cli = cli
        self.strict = strict
        self.missings = set()

    def prepare(self, value):
        value = unicode(value).strip()
        if not self.strict:
            value = slugify(value)
        return value

    def check(self, values, output=None):
        """Unknown values are asked interactively when ``self.cli`` is set,
        otherwise recorded in ``missings``."""
        if not output or output == 'silent':
            return
        msgstr = unicode(_(u"Choice for \"%s\" is not available. "
                           u"Which one is relevant?\n"))
        msgstr += u"1. True\n"
        msgstr += u"2. False\n"
        msgstr += u"3. Empty\n"
        for value in values:
            value = self.prepare(value)
            if value in self.dct:
                continue
            if not self.cli:
                self.missings.add(value)
                continue
            res = None
            while res not in range(1, 4):
                sys.stdout.write(msgstr % value)
                res = raw_input(">>> ")
                try:
                    res = int(res)
                except ValueError:
                    pass
            if res == 1:
                self.dct[value] = True
            elif res == 2:
                self.dct[value] = False
            else:
                self.dct[value] = None

    def format(self, value):
        value = self.prepare(value)
        if value in self.dct:
            return self.dct[value]


class Importer(object):
    """Base class for tabular imports.

    Subclasses describe the columns in ``LINE_FORMAT`` (ImportFormater or
    None per column) and the model to fill in ``OBJECT_CLS``. Errors are
    collected in ``self.errors`` as (line, col, message) tuples, with
    1-based line/column numbers.
    """
    DESC = ""
    LINE_FORMAT = []
    OBJECT_CLS = None
    IMPORTED_LINE_FIELD = None
    UNICITY_KEYS = []
    DEFAULTS = {}
    ERRORS = {
        'header_check': _(
            u"The given file is not correct. Check the file "
            u"format. If you use a CSV file: check that column separator "
            u"and encoding are similar to the ones used by the reference "
            u"file."),
        'too_many_cols': _(u"Too many cols (%(user_col)d) when "
                           u"maximum is %(ref_col)d"),
        'no_data': _(u"No data provided"),
        'value_required': _(u"Value is required"),
        'not_enough_cols': _(u"At least %d columns must be filled"),
        'regex_not_match': _(u"The regexp doesn't match.")
    }

    def __init__(self, skip_lines=0, reference_header=None,
                 check_col_num=False, test=False, check_validity=True,
                 history_modifier=None, output='silent'):
        """
         * skip_lines must be set if the data provided has got headers lines.
         * a reference_header can be provided to perform a data compliance
           check. It can be useful to warn about bad parsing.
         * test doesn't write in the database
         * check_validity rewrite a CSV file to be compared
        """
        self.message = ''
        self.skip_lines = skip_lines
        self.reference_header = reference_header
        self.test = test
        self.errors = []  # list of (line, col, message)
        self.messages = []  # list of (line, col, message)
        self.number_updated = 0
        self.number_created = 0
        self.check_col_num = check_col_num
        self.check_validity = check_validity
        self.line_format = copy.copy(self.LINE_FORMAT)
        self._initialized = False
        self._defaults = self.DEFAULTS.copy()
        self.history_modifier = history_modifier
        self.output = output
        if not self.history_modifier:
            # get the first admin
            self.history_modifier = User.objects.filter(
                is_superuser=True).order_by('pk')[0]

    def initialize(self, table, output='silent'):
        """
        copy vals in columns and initialize formaters
         * output:
          - 'silent': no associations
          - 'cli': output by command line interface and stocked in the
            database
          - 'db': output on the database with no interactive association
            (further exploitation by web interface)
        """
        assert output in ('silent', 'cli', 'db')
        # one (possibly empty) list per expected column, so that a column
        # absent from every row still gets its formater initialized
        vals = [[] for __ in self.line_format]
        for idx_line, line in enumerate(table):
            if self.skip_lines > idx_line:
                continue
            for idx_col, val in enumerate(line):
                if idx_col >= len(self.line_format):
                    break
                vals[idx_col].append(val)
        for idx, formater in enumerate(self.line_format):
            if formater:
                formater.init(vals[idx], output)
        self._initialized = True

    def importation(self, table):
        """Import ``table`` (an iterable of rows)."""
        # the table is read twice (initialize then _importation): a one-shot
        # iterator such as a csv reader must be materialized first
        table = list(table)
        self.validity_file = None
        if not self._initialized:
            self.initialize(table, self.output)
        if self.check_validity:
            with NamedTemporaryFile(delete=False) as validity_file:
                self.validity_file = UnicodeWriter(
                    validity_file, delimiter=',', quotechar='"',
                    quoting=csv.QUOTE_MINIMAL)
                self._importation(table)
        else:
            self._importation(table)

    @classmethod
    def _field_name_to_data_dict(cls, field_name, value, data,
                                 force_value=False, concat=False):
        """Store ``value`` in the nested ``data`` dict.

        ``field_name`` (or each name of a list/tuple) is split on ``__``;
        intermediate keys become nested dicts. An existing non-empty value is
        kept unless ``force_value`` (with a non-empty value) or ``concat``
        (newline-joined append) is set.
        """
        field_names = field_name
        if type(field_names) not in (list, tuple):
            field_names = [field_name]
        for field_name in field_names:
            keys = field_name.split('__')
            current_data = data
            for idx, key in enumerate(keys):
                if idx == (len(keys) - 1):  # last
                    if concat:
                        existing = current_data.get(key)
                        current_data[key] = u"\n".join(
                            part for part in (existing, value) if part)
                    elif force_value and value:
                        current_data[key] = value
                    elif key not in current_data or not current_data[key]:
                        current_data[key] = value
                elif key not in current_data:
                    current_data[key] = {}
                current_data = current_data[key]
        return data

    def _importation(self, table):
        table = list(table)
        if not table or not table[0]:
            raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER)
        if self.check_col_num and len(table[0]) > len(self.line_format):
            raise ImporterError(self.ERRORS['too_many_cols'] % {
                'user_col': len(table[0]), 'ref_col': len(self.line_format)})
        self.errors = []
        self.messages = []
        self.number_imported = 0
        # min col number to be filled: 1-based position of the last
        # required column (0 when no column is required)
        self.min_col_number = 0
        for idx, formater in enumerate(self.line_format):
            if formater and formater.required:
                self.min_col_number = idx + 1
        # check the conformity with the reference header
        if self.reference_header and \
           self.skip_lines and \
           self.reference_header != table[0]:
            raise ImporterError(self.ERRORS['header_check'],
                                type=ImporterError.HEADER)
        self.now = datetime.datetime.now()
        start = datetime.datetime.now()
        total = len(table)
        if self.output:
            sys.stdout.write("\n")
        for idx_line, line in enumerate(table):
            if self.output:
                left = None
                if idx_line > 10:
                    ellapsed = datetime.datetime.now() - start
                    time_by_item = ellapsed / idx_line
                    if time_by_item:
                        left = ((total - idx_line) * time_by_item).seconds
                txt = "\r* %d/%d" % (idx_line + 1, total)
                if left:
                    txt += " (%d seconds left)" % left
                sys.stdout.write(txt)
                sys.stdout.flush()
            try:
                self._line_processing(idx_line, line)
            except ImporterError as msg:
                # 1-based line number, like every other recorded error
                self.errors.append((idx_line + 1, None, msg))

    @staticmethod
    def _line_to_csv(line):
        """Serialize ``line`` as one CSV row and return it as unicode."""
        # the Python 2 csv module writes byte strings: encode unicode cells
        # to UTF-8 before writing, decode the result afterwards
        output = io.BytesIO()
        writer = csv.writer(output)
        writer.writerow([cell.encode('utf-8') if isinstance(cell, unicode)
                         else cell for cell in line])
        return output.getvalue().decode('utf-8')

    def _line_processing(self, idx_line, line):
        """Format one row, then (unless testing) save the object, its
        "through" items and run the post-processing hooks."""
        if self.skip_lines > idx_line:
            if self.validity_file:
                self.validity_file.writerow(line)
            return
        if not line:
            if self.validity_file:
                self.validity_file.writerow([])
            return
        self._throughs = []  # list of (formater, value)
        self._post_processing = []  # list of (formater, value)
        data = {}

        # keep in database the raw line for testing purpose
        if self.IMPORTED_LINE_FIELD:
            data[self.IMPORTED_LINE_FIELD] = self._line_to_csv(line)

        n = datetime.datetime.now()
        logger.debug('%s - Processing line %d', n - self.now, idx_line)
        self.now = n
        n2 = n
        self.c_errors = False
        c_row = []
        for idx_col, val in enumerate(line):
            try:
                self._row_processing(c_row, idx_col, idx_line, val, data)
            except Exception:
                # best effort: a failing column does not abort the line, but
                # the failure is logged instead of silently discarded
                logger.exception('Error while processing line %d, column %d',
                                 idx_line + 1, idx_col + 1)

        if self.validity_file:
            self.validity_file.writerow(c_row)
        if not self.c_errors and (idx_col + 1) < self.min_col_number:
            self.c_errors = True
            self.errors.append((
                idx_line + 1, idx_col + 1,
                self.ERRORS['not_enough_cols'] % self.min_col_number))
        if self.c_errors:
            return
        n = datetime.datetime.now()
        logger.debug('* %s - Cols read', n - n2)
        n2 = n
        if self.test:
            return
        # manage unicity of items (mainly for updates)
        if self.UNICITY_KEYS:
            data['defaults'] = {}
            for k in list(data.keys()):
                if k not in self.UNICITY_KEYS and k != 'defaults':
                    data['defaults'][k] = data.pop(k)

        if 'history_modifier' in \
                self.OBJECT_CLS._meta.get_all_field_names():
            data['history_modifier'] = self.history_modifier

        obj, created = self.get_object(self.OBJECT_CLS, data)
        if created:
            self.number_created += 1
        else:
            self.number_updated += 1

        if not created and 'defaults' in data:
            for k in data['defaults']:
                setattr(obj, k, data['defaults'][k])
            obj.save()
        n = datetime.datetime.now()
        logger.debug('* %s - Item saved', n - n2)
        n2 = n
        for formater, value in self._throughs:
            n = datetime.datetime.now()
            logger.debug('* %s - Processing formater %s', n - n2,
                         formater.field_name)
            n2 = n
            # own name: "data" must still be the line data for post_process
            through_data = {}
            if formater.through_dict:
                through_data = formater.through_dict.copy()
            if formater.through_key:
                through_data[formater.through_key] = obj
            through_data[formater.field_name] = value
            through_cls = formater.through
            if formater.through_unicity_keys:
                through_data['defaults'] = {}
                for k in list(through_data.keys()):
                    if k not in formater.through_unicity_keys \
                            and k != 'defaults':
                        through_data['defaults'][k] = through_data.pop(k)
            t_obj, t_created = through_cls.objects.get_or_create(
                **through_data)
            if not t_created and 'defaults' in through_data:
                for k in through_data['defaults']:
                    setattr(t_obj, k, through_data['defaults'][k])
                t_obj.save()

        for formater, val in self._post_processing:
            formater.post_process(obj, data, val,
                                  owner=self.history_modifier)

    def _row_processing(self, c_row, idx_col, idx_line, val, data):
        """Format the cell ``val`` of column ``idx_col`` into ``data``.

        Appends the formatted values (or the raw value) to ``c_row`` for the
        validity file and records errors in ``self.errors``.
        """
        if idx_col >= len(self.line_format):
            return

        formater = self.line_format[idx_col]

        if formater and formater.post_processing:
            self._post_processing.append((formater, val))

        if not formater or not formater.field_name:
            if self.validity_file:
                c_row.append(val)
            return

        # regex management
        if formater.regexp:
            # multiline regexp is a mess...
            val = val.replace('\n', NEW_LINE_BREAK)
            match = formater.regexp.match(val)
            if not match:
                if formater.required:
                    self.errors.append((idx_line + 1, idx_col + 1,
                                        self.ERRORS['value_required']))
                elif not val.strip():
                    c_row.append("")
                    return
                self.c_errors = True
                val = val.replace(NEW_LINE_BREAK, '\n')
                self.errors.append((
                    idx_line + 1, idx_col + 1,
                    unicode(self.ERRORS['regex_not_match']) + val))
                c_row.append("")
                return
            val_group = [v.replace(NEW_LINE_BREAK, '\n')
                         for v in match.groups()]
        else:
            val_group = [val]

        c_values = []
        for idx_v, v in enumerate(val_group):
            self.message = ''
            func = formater.formater
            if type(func) in (list, tuple):
                func = func[idx_v]
            if not callable(func) and type(func) in (unicode, str):
                func = getattr(self, func)
            values = [v]
            many_values = getattr(func, 'many_split', None)
            if many_values:
                values = re.split(func.many_split, values[0])
            formated_values = []
            for sub_value in values:
                try:
                    if formater.regexp_formater_args:
                        args = [val_group[arg_idx] for arg_idx in
                                formater.regexp_formater_args[idx_v]]
                        value = func.format(*args)
                    else:
                        value = func.format(sub_value)
                except ValueError as e:
                    if formater.required:
                        self.c_errors = True
                    self.errors.append((idx_line + 1, idx_col + 1,
                                        e.message))
                    c_values.append(None)
                    return
                formated_values.append(value)
            if self.message:
                self.messages.append(self.message)
            value = formated_values if many_values else formated_values[0]
            c_values.append(value)

            if value is None and formater.required:
                self.c_errors = True
                self.errors.append((idx_line + 1, idx_col + 1,
                                    self.ERRORS['value_required']))
                return

            field_name = formater.field_name
            if type(field_name) in (list, tuple):
                field_name = field_name[idx_v]
            field_names = [field_name]
            for duplicate_field in formater.duplicate_fields:
                if type(duplicate_field) in (list, tuple):
                    duplicate_field = duplicate_field[idx_v]
                field_names.append(duplicate_field)

            if formater.through:
                self._throughs.append((formater, value))
            else:
                for field_name in field_names:
                    self._field_name_to_data_dict(
                        field_name, value, data, formater.force_value,
                        concat=formater.concat)
        if formater.reverse_for_test:
            # c_values is a list: pass it positionally
            c_row.append(formater.reverse_for_test(*c_values))
        else:
            c_row.append(unicode(c_values))

    def get_object(self, cls, data, path=None):
        """Get or create the ``cls`` instance described by ``data``.

        Nested dicts on relation fields are resolved recursively into
        instances, many-to-many values are added after the get_or_create.
        ``history_modifier`` and the non-dict entries of ``data['defaults']``
        are passed as get_or_create ``defaults``.

        :returns: ``(obj, created)`` when ``data`` is a non-empty dict,
            otherwise ``data`` unchanged.
        :raises ImporterError: on database integrity errors.
        """
        path = path or []
        if not data or type(data) != dict:
            return data
        m2ms = []
        for attribute in list(data.keys()):
            # 'defaults' holds get_or_create defaults, not a model field
            if attribute == 'defaults' or not data[attribute]:
                continue
            c_path = path[:]
            field_object, model, direct, m2m = \
                cls._meta.get_field_by_name(attribute)
            if hasattr(field_object, 'rel') and field_object.rel and \
               type(data[attribute]) == dict:
                c_path.append(attribute)
                # put history_modifier for every created item
                if 'history_modifier' in \
                        field_object.rel.to._meta.get_all_field_names():
                    data[attribute]['history_modifier'] = \
                        self.history_modifier
                data[attribute], created = self.get_object(
                    field_object.rel.to, data[attribute], c_path)
            if m2m:
                val = data.pop(attribute)
                model = field_object.model
                if val.__class__ != model and type(val) == dict:
                    if 'history_modifier' in \
                            model._meta.get_all_field_names():
                        val['history_modifier'] = self.history_modifier
                    val, created = field_object.model.objects.get_or_create(
                        **val)
                m2ms.append((attribute, val))
        path = tuple(path)
        if path in self._defaults:
            for k in self._defaults[path]:
                if k not in data or not data[k]:
                    data[k] = self._defaults[path][k]

        # filter default values
        create_dict = copy.deepcopy(data)
        for k in list(create_dict.keys()):
            if type(create_dict[k]) == dict:
                create_dict.pop(k)
        defaults = {}
        if type(data.get('defaults')) == dict:
            defaults.update((k, v) for k, v in data['defaults'].items()
                            if type(v) != dict)
        if 'history_modifier' in create_dict:
            defaults['history_modifier'] = create_dict.pop('history_modifier')
        try:
            try:
                dct = create_dict.copy()
                dct['defaults'] = defaults
                obj, created = cls.objects.get_or_create(**dct)
            except IntegrityError as e:
                raise IntegrityError(e.message)
            except Exception:
                # e.g. several objects match: reuse the first one
                created = False
                obj = cls.objects.filter(**create_dict).all()[0]
            for attr, value in m2ms:
                values = [value]
                if type(value) in (list, tuple):
                    values = value
                for v in values:
                    getattr(obj, attr).add(v)
        except IntegrityError as e:
            raise ImporterError(
                "Erreur d'import %s, contexte : %s, erreur : %s"
                % (unicode(cls), unicode(data),
                   e.message.decode('utf-8')))
        return obj, created

    def get_csv_errors(self):
        """Return the collected errors as CSV text ('' when none)."""
        if not self.errors:
            return ""
        csv_errors = ["line,col,error"]
        for line, col, error in self.errors:
            csv_errors.append(u'"%s","%s","%s"' % (
                line and unicode(line) or '-',
                col and unicode(col) or '-',
                # double embedded quotes so that each field stays valid CSV
                unicode(error).replace(u'"', u'""')))
        return u"\n".join(csv_errors)

    @classmethod
    def choices_check(cls, choices):
        """Return a function validating that a value is one of the labels of
        ``choices`` ((key, label) pairs); '' gives None."""
        # built once: "choices" may be a one-shot iterable
        choices_dct = dict(choices)

        def function(value):
            value = value.strip()
            if not value:
                return
            if value not in choices_dct.values():
                raise ValueError(_(u"\"%(value)s\" not in %(values)s") % {
                    'value': value,
                    'values': u", ".join(
                        [val for val in choices_dct.values()])
                })
            return value
        return function