diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 239 | 
1 files changed, 26 insertions, 213 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 0ce61ba01..87b3a40b0 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -17,7 +17,7 @@  # See the file COPYING for details. -import copy, csv, datetime, logging, sys +import copy, csv, datetime, logging, re, sys  from tempfile import NamedTemporaryFile  from django.contrib.auth.models import User @@ -216,16 +216,16 @@ class StrChoiceFormater(Formater):          msgstr += unicode(_(u"%d. None of the above - skip")) % idx + u"\n"          if self.many_split:              new_values = [] -            r = re.compile(func.many_split) +            r = re.compile(self.many_split)              for value in values:                  new_values += r.split(value)              values = new_values          for value in values: -            base_value = copy(value) +            base_value = copy.copy(value)              value = self.prepare(value)              if value in self.equiv_dict:                  continue -            if not self.cli: +            if output != 'cli':                  self.missings.add(value)                  continue              res = None @@ -263,9 +263,12 @@ class StrChoiceFormater(Formater):  class TypeFormater(StrChoiceFormater):      def __init__(self, model, cli=False, defaults={}, many_split=False): +        self.create = True +        self.strict = False          self.model = model          self.defaults = defaults          self.many_split = many_split +        self.missings = set()          self.equiv_dict, self.choices = {}, []          for item in model.objects.all():              self.choices.append((item.pk, unicode(item))) @@ -348,6 +351,7 @@ class StrToBoolean(Formater):  logger = logging.getLogger(__name__)  class Importer(object): +    DESC = ""      LINE_FORMAT = []      OBJECT_CLS = None      IMPORTED_LINE_FIELD = None @@ -366,19 +370,18 @@ class Importer(object):          'regex_not_match':_(u"The regexp doesn't match.")          } -    def __init__(self, skip_first_line=False, reference_header=None, +    def __init__(self, skip_lines=0, reference_header=None,                   check_col_num=False, test=False, check_validity=True, -                 history_modifier=None, output=None): +                 history_modifier=None, output='silent'):          """ -         * skip_first_line must be set to True if the data provided has got -           an header. +         * skip_line must be set if the data provided has got headers lines.           * a reference_header can be provided to perform a data compliance             check. It can be useful to warn about bad parsing.           * test doesn't write in the database           * check_validity rewrite a CSV file to be compared          """          self.message = '' -        self.skip_first_line = skip_first_line +        self.skip_lines = skip_lines          self.reference_header = reference_header          self.test = test          self.errors = [] # list of (line, col, message) @@ -408,7 +411,7 @@ class Importer(object):          assert output in ('silent', 'cli', 'db')          vals = []          for idx_line, line in enumerate(table): -            if (self.skip_first_line and not idx_line): +            if self.skip_lines > idx_line:                  continue              for idx_col, val in enumerate(line):                  if idx_col >= len(self.line_format): @@ -424,7 +427,7 @@ class Importer(object):      def importation(self, table):          self.validity_file = None          if not self._initialized: -            self.initialize(table) +            self.initialize(table, self.output)          if self.check_validity:              with NamedTemporaryFile(delete=False) as validity_file:                  self.validity_file = UnicodeWriter(validity_file, @@ -479,7 +482,7 @@ class Importer(object):          self.min_col_number = len(self.line_format) - idx_last_col          # check the conformity with the reference header          if self.reference_header and \ -           self.skip_first_line and \ +           self.skip_lines and \             self.reference_header != table[0]:              raise ImporterError(self.ERRORS['header_check'],                                  type=ImporterError.HEADER) @@ -507,7 +510,7 @@ class Importer(object):                  self.errors.append((idx_line, None, msg))      def _line_processing(self, idx_line, line): -        if (self.skip_first_line and not idx_line): +        if self.skip_lines > idx_line:              if self.validity_file:                  self.validity_file.writerow(line)              return @@ -703,201 +706,6 @@ class Importer(object):          else:              c_row.append(unicode(c_values)) - -    """ -    def _importation(self, table): -        table = list(table) -        if not table or not table[0]: -            raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER) -        if self.check_col_num and len(table[0]) > len(self.line_format): -            raise ImporterError(self.ERRORS['too_many_cols'] % { -                     'user_col':len(table[0]), 'ref_col':len(self.line_format)}) -        self.errors = [] -        self.messages = [] -        self.number_imported = 0 -        # index of the last required column -        for idx_last_col, formater in enumerate(reversed(self.line_format)): -            if formater.required: -                break -        else: -            idx_last_col += 1 -        # min col number to be filled -        min_col_number = len(self.line_format) - idx_last_col -        # check the conformity with the reference header -        if self.reference_header and \ -           self.skip_first_line and \ -           self.reference_header != table[0]: -            raise ImporterError(self.ERRORS['header_check'], -                                type=ImporterError.HEADER) -        now = datetime.datetime.now() -        for idx_line, line in enumerate(table): -            #self._line_processing() - -            if (self.skip_first_line and not idx_line): -                if self.validity_file: -                    self.validity_file.writerow(line) -                continue -            if not line: -                if self.validity_file: -                    self.validity_file.writerow([]) -                continue -            self.throughs = [] # list of (formater, value) -            self.post_processing = [] # list of (formater, value) -            data = {} - -            # keep in database the raw line for testing purpose -            if self.IMPORTED_LINE_FIELD: -                output = io.StringIO() -                writer = csv.writer(output) -                writer.writerow(line) -                data[self.IMPORTED_LINE_FIELD] = output.getvalue() - -            n = datetime.datetime.now() -            logger.debug('%s - Processing line %d' % (unicode(n-now), idx_line)) -            now = n -            n2 = n -            self.c_errors = False -            c_row = [] -            for idx_col, val in enumerate(line): -                #self._row_processing(self, c_row, idx_col, val): - -                if idx_col >= len(self.line_format): -                    break -                formater = self.line_format[idx_col] -                if not formater.field_name: -                    if self.validity_file: -                        c_row.append(val) -                    continue -                if formater.regexp: -                    # multiline regexp is a mess... -                    val = val.replace('\n', NEW_LINE_BREAK) -                    match = formater.regexp.match(val) -                    if not match: -                        if formater.required: -                            self.errors.append((idx_line+1, idx_col+1, -                                               self.ERRORS['value_required'])) -                        elif not val.strip(): -                            c_row.append("") -                            continue -                        c_errors = True -                        val = val.replace(NEW_LINE_BREAK, '\n') -                        self.errors.append((idx_line+1, idx_col+1, -                                 unicode(self.ERRORS['regex_not_match']) + val)) -                        c_row.append("") -                        continue -                    val_group = [v.replace(NEW_LINE_BREAK, '\n') -                                 for v in match.groups()] -                else: -                    val_group = [val] -                c_values = [] -                for idx_v, v in enumerate(val_group): -                    self.message = '' -                    func = formater.formater -                    if type(func) in (list, tuple): -                        func = func[idx_v] -                    if not callable(func) and type(func) in (unicode, str): -                        func = getattr(self, func) -                    value = None -                    try: -                        if formater.regexp_formater_args: -                            args = [] -                            for idx in formater.regexp_formater_args[idx_v]: -                                args.append(val_group[idx]) -                            value = func.format(*args) -                        else: -                            value = func.format(v) -                    except ValueError, e: -                        c_errors = True -                        self.errors.append((idx_line+1, idx_col+1, e.message)) -                        c_values.append(None) -                        continue -                    if self.message: -                        self.messages.append(self.message) -                    c_values.append(value) -                    if value == None: -                        if formater.required: -                            c_errors = True -                            self.errors.append((idx_line+1, idx_col+1, -                                               self.ERRORS['value_required'])) -                        continue -                    field_name = formater.field_name -                    if type(field_name) in (list, tuple): -                        field_name = field_name[idx_v] -                    field_names = [field_name] -                    if formater.duplicate_field: -                        duplicate_field = formater.duplicate_field -                        if type(duplicate_field) in (list, tuple): -                            duplicate_field = duplicate_field[idx_v] -                        field_names += [duplicate_field] - - -                    if formater.through: -                        throughs.append((formater, value)) -                    else: -                        for field_name in field_names: -                            self._field_name_to_data_dict(field_name, -                                                          value, data) -                if formater.reverse_for_test: -                    c_row.append(formater.reverse_for_test(**c_values)) -                else: -                    c_row.append(unicode(c_values)) - -            if self.validity_file: -                self.validity_file.writerow(c_row) -            if not self.c_errors and (idx_col + 1) < min_col_number: -                self.c_errors = True -                self.errors.append((idx_line+1, idx_col+1, -                               self.ERRORS['not_enough_cols'] % min_col_number)) -            if self.c_errors: -                continue -            n = datetime.datetime.now() -            logger.debug('* %s - Cols read' % (unicode(n-n2))) -            n2 = n -            if self.test: -                continue -            # manage unicity of items (mainly for updates) -            self.number_imported += 1 -            if self.UNICITY_KEYS: -                data['defaults'] = {} -                for k in data.keys(): -                    if k not in self.UNICITY_KEYS \ -                       and k != 'defaults': -                        data['defaults'][k] = data.pop(k) - -            obj, created = self.get_object(self.OBJECT_CLS, data) - -            if not created and 'defaults' in data: -                for k in data['defaults']: -                    setattr(obj, k, data['defaults'][k]) -                obj.save() -            n = datetime.datetime.now() -            logger.debug('* %s - Item saved' % (unicode(n-n2))) -            n2 = n -            for formater, value in self.throughs: -                n = datetime.datetime.now() -                logger.debug('* %s - Processing formater %s' % (unicode(n-n2), -                                                        formater.field_name)) -                n2 = n -                data = {} -                if formater.through_dict: -                    data = formater.through_dict.copy() -                if formater.through_key: -                    data[formater.through_key] = obj -                data[formater.field_name] = value -                through_cls = formater.through -                if formater.through_unicity_keys: -                    data['defaults'] = {} -                    for k in data.keys(): -                        if k not in formater.through_unicity_keys \ -                           and k != 'defaults': -                            data['defaults'][k] = data.pop(k) -                t_obj, created = through_cls.objects.get_or_create(**data) -                if not created and 'defaults' in data: -                    for k in data['defaults']: -                        setattr(t_obj, k, data['defaults'][k]) -                    t_obj.save() -    """ -      def get_object(self, cls, data, path=[]):          m2ms = []          if data and type(data) == dict: @@ -937,6 +745,8 @@ class Importer(object):                      dct = create_dict.copy()                      dct['defaults'] = defaults                      obj, created = cls.objects.get_or_create(**create_dict) +                except IntegrityError as e: +                    raise IntegrityError(e.message)                  except:                      created = False                      obj = cls.objects.filter(**create_dict).all()[0] @@ -946,16 +756,19 @@ class Importer(object):                          values = value                      for v in values:                          getattr(obj, attr).add(v) -            except IntegrityError: -                raise ImporterError("Erreur d'import %s, contexte : %s" \ -                                % (unicode(cls), unicode(data))) +            except IntegrityError as e: +                raise ImporterError("Erreur d'import %s, contexte : %s, erreur : %s" \ +                        % (unicode(cls), unicode(data), e.message.decode('utf-8')))              return obj, created          return data      def get_csv_errors(self): -        csv_errors = [] +        if not self.errors: +            return "" +        csv_errors = ["line,col,error"]          for line, col, error in self.errors: -            csv_errors.append(u'"%d","%d","%s"' % (line or 0, col or 0, +            csv_errors.append(u'"%s","%s","%s"' % (line and unicode(line) or '-', +                                                   col and unicode(col) or '-',                                                          unicode(error)))          return u"\n".join(csv_errors) | 
