diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 405 | 
1 files changed, 251 insertions, 154 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 2ab5ba28f..fa8c6a2e0 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1,6 +1,6 @@  #!/usr/bin/env python  # -*- coding: utf-8 -*- -# Copyright (C) 2013-2015  Étienne Loks  <etienne.loks_AT_peacefrogsDOTnet> +# Copyright (C) 2013-2017 Étienne Loks  <etienne.loks_AT_peacefrogsDOTnet>  # This program is free software: you can redistribute it and/or modify  # it under the terms of the GNU Affero General Public License as @@ -36,11 +36,28 @@ from django.db.models import Q  from django.template.defaultfilters import slugify  from django.utils.translation import ugettext_lazy as _ +from ishtar_common.utils import get_all_field_names + +  NEW_LINE_BREAK = '#####@@@#####'  RE_FILTER_CEDEX = re.compile("(.*) *(?: *CEDEX|cedex|Cedex|Cédex|cédex *\d*)") +def post_importer_action(func): +    def wrapper(self, context, value): +        return func(self, context, value) +    wrapper.importer_trigger = 'post' +    return wrapper + + +def pre_importer_action(func): +    def wrapper(self, context, value): +        return func(self, context, value) +    wrapper.importer_trigger = 'pre' +    return wrapper + +  class ImportFormater(object):      def __init__(self, field_name, formater=None, required=True, through=None,                   through_key=None, through_dict=None, @@ -343,7 +360,7 @@ class StrChoiceFormater(Formater, ChoiceChecker):                  if hasattr(value, 'pk'):                      value = value.pk                  q['value'] = value -                with transaction.commit_on_success(): +                with transaction.atomic():                      try:                          t, created = TargetKey.objects.get_or_create(**q)                      except IntegrityError: @@ -405,7 +422,7 @@ class StrChoiceFormater(Formater, ChoiceChecker):                      target.is_set = True                      target.save()                  else: -                    with transaction.commit_on_success(): +                    with transaction.atomic():                          q['value'] = self.equiv_dict[value]                          q['is_set'] = True                          try: @@ -419,7 +436,7 @@ class StrChoiceFormater(Formater, ChoiceChecker):                       'associated_import': import_instance}                  if TargetKey.objects.filter(**q).count():                      continue -                with transaction.commit_on_success(): +                with transaction.atomic():                      try:                          TargetKey.objects.create(**q)                      except IntegrityError: @@ -473,7 +490,7 @@ class TypeFormater(StrChoiceFormater):          values = copy.copy(self.defaults)          values['label'] = value          values['txt_idx'] = slugify(value) -        if 'order' in self.model._meta.get_all_field_names(): +        if 'order' in get_all_field_names(self.model):              order = 1              q = self.model.objects.values('order').order_by('-order')              if q.count(): @@ -812,7 +829,7 @@ class Importer(object):          self.output = output          if not self.history_modifier:              if self.import_instance: -                self.history_modifier = self.import_instance.user +                self.history_modifier = self.import_instance.user.user_ptr              else:                  # import made by the CLI: get the first admin                  self.history_modifier = User.objects.filter( @@ -825,6 +842,8 @@ class Importer(object):              for related_key in item.RELATED_POST_PROCESS:                  for related in getattr(item, related_key).all():                      related.save() +        for func, context, value in self._item_post_processing: +            getattr(item, func)(context, value)          return item      def initialize(self, table, output='silent', choose_default=False): @@ -990,6 +1009,7 @@ class Importer(object):              return          self._throughs = []  # list of (formater, value)          self._post_processing = []  # list of (formater, value) +        self._item_post_processing = []          data = {}          # keep in database the raw line for testing purpose @@ -1026,8 +1046,7 @@ class Importer(object):          if self.test:              return          # manage unicity of items (mainly for updates) -        if 'history_modifier' in \ -                self.OBJECT_CLS._meta.get_all_field_names(): +        if 'history_modifier' in get_all_field_names(self.OBJECT_CLS):              data['history_modifier'] = self.history_modifier          obj, created = self.get_object(self.OBJECT_CLS, data) @@ -1096,7 +1115,7 @@ class Importer(object):          for formater, val in self._post_processing:              formater.post_process(obj, data, val, owner=self.history_modifier) -        obj = self.post_processing(obj, data) +        self.post_processing(obj, data)      def _row_processing(self, c_row, idx_col, idx_line, val, data):          if idx_col >= len(self.line_format): @@ -1239,162 +1258,241 @@ class Importer(object):                          concat_str=concat_str[idx])          c_row.append(u" ; ".join([v for v in c_values])) +    def _get_field_m2m(self, attribute, data, c_path, new_created, +                       field_object): +        """ +        Manage and m2m field from raw data + +        :param attribute: attribute name +        :param data: current data dictionary +        :param c_path: attribute path from the main model point of view +        :param new_created: dict of forced newly created items to prevent +        multiple creation +        :param field_object: django field object for this attribute +        :return: None +        """ +        m2ms = [] + +        many_values = data.pop(attribute) +        if hasattr(field_object, 'rel'): +            model = field_object.rel.to +        elif hasattr(field_object, 'to'): +            model = field_object.to +        elif hasattr(field_object, 'model'): +            model = field_object.model +        if type(many_values) not in (list, tuple): +            many_values = [many_values] + +        for val in many_values: +            if val.__class__ == model: +                # the value is a model instance: it is OK! +                m2ms.append((attribute, val)) +                continue +            if type(val) != dict: +                # value is not a dict, we don't know what to do with it... +                continue +            vals = [] + +            # contruct many dict for each values +            default_dict = {} + +            # # get default values +            p = [attribute] +            if c_path: +                p = list(c_path) + p +            p = tuple(p) +            if p in self._defaults: +                for k in self._defaults[p]: +                    default_dict[k] = self._defaults[p][k] + +            # # init with simple values that will be duplicated +            for key in val.keys(): +                if type(val[key]) not in (list, tuple): +                    default_dict[key] = val[key] +            vals.append(default_dict.copy()) + +            # # manage multiple values +            for key in val.keys(): +                if type(val[key]) in (list, tuple): +                    for idx, v in enumerate(val[key]): +                        if len(vals) <= idx: +                            vals.append(default_dict.copy()) +                        vals[idx][key] = v + +            # check that m2m are not empty +            notempty = False +            for dct in vals: +                for k in dct: +                    if dct[k] not in ("", None): +                        notempty = True +                        break +            if not notempty: +                continue + +            field_names = get_all_field_names(model) +            for v in vals: +                if 'history_modifier' in field_names: +                    if 'defaults' not in v: +                        v['defaults'] = {} +                    v['defaults']['history_modifier'] = \ +                        self.history_modifier +                m2m_m2ms = [] +                c_c_path = c_path[:] +                for k in v.keys(): +                    if k not in field_names: +                        continue +                    self.get_field(model, k, v, m2m_m2ms, c_c_path, +                                   new_created) +                if '__force_new' in v: +                    created = v.pop('__force_new') +                    key = u";".join([u"{}-{}".format(k, v[k]) +                                     for k in sorted(v.keys())]) +                    # only one forced creation +                    if attribute in new_created \ +                            and key in new_created[attribute]: +                        continue +                    if attribute not in new_created: +                        new_created[attribute] = [] +                    new_created[attribute].append(key) +                    has_values = bool([1 for k in v if v[k]]) +                    if has_values: +                        if self.MODEL_CREATION_LIMIT and \ +                                        model not in self.MODEL_CREATION_LIMIT: +                            raise self._get_improperly_conf_error(model) +                        v = model.objects.create(**v) +                    else: +                        continue +                else: +                    v['defaults'] = v.get('defaults', {}) +                    extra_fields = {} +                    # "File" type is a temp object and can be different +                    # for the same filename - it must be treated +                    # separately +                    for field in model._meta.fields: +                        k = field.name +                        # attr_class is a FileField attribute +                        if hasattr(field, 'attr_class') and k in v: +                            extra_fields[k] = v.pop(k) +                    created = False +                    if not self.MODEL_CREATION_LIMIT or \ +                                    model in self.MODEL_CREATION_LIMIT: +                        v, created = model.objects.get_or_create( +                            **v) +                    else: +                        get_v = v.copy() +                        if 'defaults' in get_v: +                            get_v.pop('defaults') +                        try: +                            v = model.objects.get(**get_v) +                        except model.DoesNotExist: +                            raise self._get_does_not_exist_in_db_error( +                                model, get_v) +                    changed = False +                    for k in extra_fields.keys(): +                        if extra_fields[k]: +                            changed = True +                            setattr(v, k, extra_fields[k]) +                    if changed: +                        v.save() +                for att, objs in m2m_m2ms: +                    if type(objs) not in (list, tuple): +                        objs = [objs] +                    for obj in objs: +                        getattr(v, att).add(obj) +                if self.import_instance \ +                        and hasattr(v, 'imports') and created: +                    v.imports.add(self.import_instance) +                m2ms.append((attribute, v)) +        return m2ms + +    def _set_importer_trigger(self, cls, attribute, data): +        """ +        An importer trigger is used. Stock it for later execution and remove +        it from current data dict. + +        :param cls: current model +        :param attribute: attribute name +        :param data: current data dictionary +        :return: None +        """ +        func = getattr(cls, attribute) +        if func.importer_trigger == 'pre': +            pass  # TODO +        elif func.importer_trigger == 'post': +            self._item_post_processing.append([attribute, data, +                                               data[attribute]]) +        else: +            logger.warning("Unknow importer_trigger '{}' for '{}'".format( +                func.importer_trigger, attribute +            )) +        data.pop(attribute) +      def get_field(self, cls, attribute, data, m2ms, c_path, new_created): +        """ +        Get field from raw data + +        :param cls: current model +        :param attribute: attribute name +        :param data: current data dictionary +        :param m2ms: many to many list of tuple: (m2m key, m2m value) +        :param c_path: attribute path from the main model point of view +        :param new_created: dict of forced newly created items to prevent +        multiple creation +        :return: None +        """ +        if hasattr(cls, attribute) and \ +                getattr(getattr(cls, attribute), 'importer_trigger', None): +            # importer trigger +            self._set_importer_trigger(cls, attribute, data) +            return          try: -            field_object, model, direct, m2m = \ -                cls._meta.get_field_by_name(attribute) +            field_object = cls._meta.get_field(attribute)          except FieldDoesNotExist:              raise ImporterError(unicode(                  _(u"Importer configuration error: field \"{}\" does not exist "                    u"for {}.")).format(attribute, cls._meta.verbose_name)) -        if m2m: -            many_values = data.pop(attribute) -            if hasattr(field_object, 'rel'): -                model = field_object.rel.to -            elif hasattr(field_object, 'to'): -                model = field_object.to -            elif hasattr(field_object, 'model'): -                model = field_object.model -            if type(many_values) not in (list, tuple): -                many_values = [many_values] -            for val in many_values: -                if val.__class__ == model: -                    m2ms.append((attribute, val)) -                elif val.__class__ != model and type(val) == dict: -                    vals = [] - -                    # contruct many dict for each values -                    default_dict = {} - -                    # # get default values -                    p = [attribute] -                    if c_path: -                        p = list(c_path) + p -                    p = tuple(p) -                    if p in self._defaults: -                        for k in self._defaults[p]: -                            default_dict[k] = self._defaults[p][k] -                    # # init with simple values that will be duplicated -                    for key in val.keys(): -                        if type(val[key]) not in (list, tuple): -                            default_dict[key] = val[key] -                    vals.append(default_dict.copy()) -                    # # manage multiple values -                    for key in val.keys(): -                        if type(val[key]) in (list, tuple): -                            for idx, v in enumerate(val[key]): -                                if len(vals) <= idx: -                                    vals.append(default_dict.copy()) -                                vals[idx][key] = v - -                    # check that m2m are not empty -                    notempty = False -                    for dct in vals: -                        for k in dct: -                            if dct[k] not in ("", None): -                                notempty = True -                                break -                    if not notempty: -                        continue - -                    field_names = model._meta.get_all_field_names() -                    for v in vals: -                        if 'history_modifier' in field_names: -                            if 'defaults' not in v: -                                v['defaults'] = {} -                            v['defaults']['history_modifier'] = \ -                                self.history_modifier -                        m2m_m2ms = [] -                        c_c_path = c_path[:] -                        for k in v.keys(): -                            if k not in field_names: -                                continue -                            self.get_field(model, k, v, m2m_m2ms, c_c_path, -                                           new_created) -                        if '__force_new' in v: -                            created = v.pop('__force_new') -                            key = u";".join([u"{}-{}".format(k, v[k]) -                                             for k in sorted(v.keys())]) -                            # only one forced creation -                            if attribute in new_created \ -                                    and key in new_created[attribute]: -                                continue -                            if attribute not in new_created: -                                new_created[attribute] = [] -                            new_created[attribute].append(key) -                            has_values = bool([1 for k in v if v[k]]) -                            if has_values: -                                if self.MODEL_CREATION_LIMIT and \ -                                        model not in self.MODEL_CREATION_LIMIT: -                                    raise self._get_improperly_conf_error(model) -                                v = model.objects.create(**v) -                            else: -                                continue -                        else: -                            v['defaults'] = v.get('defaults', {}) -                            extra_fields = {} -                            # "File" type is a temp object and can be different -                            # for the same filename - it must be treated -                            # separately -                            for field in model._meta.fields: -                                k = field.name -                                # attr_class is a FileField attribute -                                if hasattr(field, 'attr_class') and k in v: -                                    extra_fields[k] = v.pop(k) -                            if not self.MODEL_CREATION_LIMIT or \ -                                    model in self.MODEL_CREATION_LIMIT: -                                v, created = model.objects.get_or_create( -                                    **v) -                            else: -                                get_v = v.copy() -                                if 'defaults' in get_v: -                                    get_v.pop('defaults') -                                try: -                                    v = model.objects.get(**get_v) -                                except model.DoesNotExist: -                                    raise self._get_does_not_exist_in_db_error( -                                        model, get_v) -                            changed = False -                            for k in extra_fields.keys(): -                                if extra_fields[k]: -                                    changed = True -                                    setattr(v, k, extra_fields[k]) -                            if changed: -                                v.save() -                        for att, objs in m2m_m2ms: -                            if type(objs) not in (list, tuple): -                                objs = [objs] -                            for obj in objs: -                                getattr(v, att).add(obj) -                        if self.import_instance \ -                           and hasattr(v, 'imports') and created: -                            v.imports.add(self.import_instance) -                        m2ms.append((attribute, v)) -        elif hasattr(field_object, 'rel') and field_object.rel: -            if type(data[attribute]) == dict: -                # put history_modifier for every created item -                if 'history_modifier' in \ -                   field_object.rel.to._meta.get_all_field_names(): -                    data[attribute]['history_modifier'] = \ -                        self.history_modifier -                try: -                    c_path.append(attribute) -                    data[attribute], created = self.get_object( -                        field_object.rel.to, data[attribute].copy(), c_path) -                except ImporterError, msg: -                    self.errors.append((self.idx_line, None, msg)) -                    data[attribute] = None -            elif type(data[attribute]) == list: -                data[attribute] = data[attribute][0] +        if field_object.many_to_many: +            m2ms += self._get_field_m2m(attribute, data, c_path, +                                        new_created, field_object) +            return +        if not hasattr(field_object, 'rel') or not field_object.rel: +            return +        if type(data[attribute]) == list: +            # extract the first item from list +            # be careful if the list has more than one item this is arbitrary +            if len(data[attribute]) > 1: +                logger.warning( +                    'Import {}: {} has many when only one is expected. Get ' +                    'the first one but it is not OK!'.format( +                     self.import_instance, attribute)) +            data[attribute] = data[attribute][0] +            return +        if type(data[attribute]) != dict: +            # we treat only dict formated values +            return +        # put history_modifier for every created item +        if 'history_modifier' in get_all_field_names(field_object.rel.to): +            data[attribute]['history_modifier'] = \ +                self.history_modifier +        try: +            c_path.append(attribute) +            data[attribute], created = self.get_object( +                field_object.rel.to, data[attribute].copy(), c_path) +        except ImporterError, msg: +            self.errors.append((self.idx_line, None, msg)) +            data[attribute] = None      def get_object(self, cls, data, path=[]):          m2ms = []          if type(data) != dict: +            # if data is not a dict we don't know what to do              return data, False +          is_empty = not bool(              [k for k in data if k not in ('history_modifier', 'defaults')                  and data[k]])          if is_empty: +            # if no value, no creation              return None, False          c_path = path[:] @@ -1408,9 +1506,8 @@ class Importer(object):                      data.pop(attribute)                      continue                  if not data[attribute]: -                    field_object, model, direct, m2m = \ -                        cls._meta.get_field_by_name(attribute) -                    if m2m: +                    field_object = cls._meta.get_field(attribute) +                    if field_object.many_to_many:                          data.pop(attribute)                      continue                  if attribute != '__force_new': @@ -1436,7 +1533,7 @@ class Importer(object):              if type(create_dict[k]) == dict:                  create_dict.pop(k)              # File doesn't like deepcopy -            if type(create_dict[k]) == File: +            elif type(create_dict[k]) == File:                  create_dict[k] = copy.copy(data[k])          # default values | 
