diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 509 | 
1 files changed, 322 insertions, 187 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 2ab5ba28f..00aa34f4c 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1,6 +1,6 @@  #!/usr/bin/env python  # -*- coding: utf-8 -*- -# Copyright (C) 2013-2015  Étienne Loks  <etienne.loks_AT_peacefrogsDOTnet> +# Copyright (C) 2013-2017 Étienne Loks  <etienne.loks_AT_peacefrogsDOTnet>  # This program is free software: you can redistribute it and/or modify  # it under the terms of the GNU Affero General Public License as @@ -36,11 +36,28 @@ from django.db.models import Q  from django.template.defaultfilters import slugify  from django.utils.translation import ugettext_lazy as _ +from ishtar_common.utils import get_all_field_names + +  NEW_LINE_BREAK = '#####@@@#####'  RE_FILTER_CEDEX = re.compile("(.*) *(?: *CEDEX|cedex|Cedex|Cédex|cédex *\d*)") +def post_importer_action(func): +    def wrapper(self, context, value): +        return func(self, context, value) +    wrapper.importer_trigger = 'post' +    return wrapper + + +def pre_importer_action(func): +    def wrapper(self, context, value): +        return func(self, context, value) +    wrapper.importer_trigger = 'pre' +    return wrapper + +  class ImportFormater(object):      def __init__(self, field_name, formater=None, required=True, through=None,                   through_key=None, through_dict=None, @@ -74,20 +91,20 @@ class ImportFormater(object):          self.force_new = force_new          self.label = label -    def reinit_db_target(self, db_target, nb=0): +    def reinit_db_target(self, db_target, nb=0, user=None):          if not self.formater:              return          if type(db_target) in (list, tuple):              db_target = db_target[nb]          if type(self.formater) not in (list, tuple):              self.formater.db_target = db_target -            self.formater.init_db_target() +            self.formater.init_db_target(user=user)          else:              for idx, formater in enumerate(self.formater):                  formater.db_target = db_target -                formater.init_db_target() +                formater.init_db_target(user=user) -    def init_db_target(self): +    def init_db_target(self, user=None):          pass      def __unicode__(self): @@ -100,7 +117,7 @@ class ImportFormater(object):          return      def init(self, vals, output=None, choose_default=False, -             import_instance=None): +             import_instance=None, user=None):          try:              lst = iter(self.formater)          except TypeError: @@ -109,7 +126,8 @@ class ImportFormater(object):              if formater:                  formater.check(vals, output, self.comment,                                 choose_default=choose_default, -                               import_instance=import_instance) +                               import_instance=import_instance, +                               user=user)      def post_process(self, obj, context, value, owner=None):          raise NotImplemented() @@ -135,12 +153,39 @@ class Formater(object):          return value      def check(self, values, output=None, comment='', choose_default=False, -              import_instance=None): +              import_instance=None, user=None):          return -    def init_db_target(self): +    def init_db_target(self, user=None):          pass +    def _base_target_filter(self, user=None): +        # set for all users +        q_or = ( +            Q(associated_import__isnull=True) & +            Q(associated_user__isnull=True) & +            Q(associated_group__isnull=True) +        ) +        if hasattr(self, 'import_instance') and self.import_instance: +            # set for current import +            q_or = q_or | Q(associated_import=self.import_instance) +            if self.import_instance.associated_group: +                # set for associated group +                q_or = q_or | Q( +                    associated_group=self.import_instance.associated_group) +        if user: +            # set for current user +            q_or = q_or | Q(associated_user=user) +        return q_or + +    def get_db_target_query(self, user=None): +        if not self.db_target: +            return +        q = self.db_target.keys.filter(is_set=True) +        q_or = self._base_target_filter(user) +        q = q.filter(q_or) +        return q +  class ChoiceChecker(object):      def report_new(self, comment): @@ -290,13 +335,12 @@ class StrChoiceFormater(Formater, ChoiceChecker):                  self.equiv_dict[value] = v          self.init_db_target() -    def init_db_target(self): +    def init_db_target(self, user=None):          if not self.db_target:              return -        q = self.db_target.keys.filter(is_set=True) -        if self.import_instance: -            q = q.filter(Q(associated_import=self.import_instance) | -                         Q(associated_import__isnull=True)) + +        q = self.get_db_target_query(user) +          for target_key in q.all():              key = target_key.key              if not self.strict: @@ -330,7 +374,10 @@ class StrChoiceFormater(Formater, ChoiceChecker):          return msgstr, idx      def check(self, values, output=None, comment='', choose_default=False, -              import_instance=None): +              import_instance=None, user=None): +        self.init_db_target(user) + +        """          from ishtar_common.models import TargetKey          if self.db_target:              q = {'target': self.db_target, @@ -343,12 +390,12 @@ class StrChoiceFormater(Formater, ChoiceChecker):                  if hasattr(value, 'pk'):                      value = value.pk                  q['value'] = value -                with transaction.commit_on_success(): +                with transaction.atomic():                      try:                          t, created = TargetKey.objects.get_or_create(**q)                      except IntegrityError:                          pass - +        """          if (not output or output == 'silent') and not choose_default:              return          if self.many_split: @@ -395,17 +442,17 @@ class StrChoiceFormater(Formater, ChoiceChecker):                  self.equiv_dict[value] = None              if self.equiv_dict[value] and self.db_target:                  from ishtar_common.models import TargetKey -                q = {'target': self.db_target, 'key': value, -                     'associated_import': import_instance, -                     } +                q = {'target': self.db_target, 'key': value}                  query = TargetKey.objects.filter(**q) +                query = query.filter(self._base_target_filter(user))                  if query.count():                      target = query.all()[0]                      target.value = self.equiv_dict[value]                      target.is_set = True                      target.save()                  else: -                    with transaction.commit_on_success(): +                    q['associated_import'] = import_instance +                    with transaction.atomic():                          q['value'] = self.equiv_dict[value]                          q['is_set'] = True                          try: @@ -415,11 +462,13 @@ class StrChoiceFormater(Formater, ChoiceChecker):          if output == 'db' and self.db_target:              from ishtar_common.models import TargetKey              for missing in self.missings: -                q = {'target': self.db_target, 'key': missing, -                     'associated_import': import_instance} -                if TargetKey.objects.filter(**q).count(): +                q = {'target': self.db_target, 'key': missing} +                query = TargetKey.objects.filter(**q) +                query = query.filter(self._base_target_filter(user)) +                if query.count():                      continue -                with transaction.commit_on_success(): +                with transaction.atomic(): +                    q['associated_import'] = import_instance                      try:                          TargetKey.objects.create(**q)                      except IntegrityError: @@ -473,7 +522,7 @@ class TypeFormater(StrChoiceFormater):          values = copy.copy(self.defaults)          values['label'] = value          values['txt_idx'] = slugify(value) -        if 'order' in self.model._meta.get_all_field_names(): +        if 'order' in get_all_field_names(self.model):              order = 1              q = self.model.objects.values('order').order_by('-order')              if q.count(): @@ -538,15 +587,16 @@ class StrToBoolean(Formater, ChoiceChecker):          self.strict = strict          self.db_target = db_target          self.missings = set() -        self.init_db_target()          self.match_table = {}          self.new_keys = {}          self.import_instance = import_instance +        self.init_db_target() -    def init_db_target(self): +    def init_db_target(self, user=None):          if not self.db_target:              return -        for target_key in self.db_target.keys.filter(is_set=True).all(): +        q = self.get_db_target_query(user) +        for target_key in q.all():              key = self.prepare(target_key.key)              if key in self.dct:                  continue @@ -560,7 +610,7 @@ class StrToBoolean(Formater, ChoiceChecker):          return value      def check(self, values, output=None, comment='', choose_default=False, -              import_instance=None): +              import_instance=None, user=None):          if (not output or output == 'silent') and not choose_default:              return          msgstr = comment + u" - " @@ -812,7 +862,7 @@ class Importer(object):          self.output = output          if not self.history_modifier:              if self.import_instance: -                self.history_modifier = self.import_instance.user +                self.history_modifier = self.import_instance.user.user_ptr              else:                  # import made by the CLI: get the first admin                  self.history_modifier = User.objects.filter( @@ -825,16 +875,20 @@ class Importer(object):              for related_key in item.RELATED_POST_PROCESS:                  for related in getattr(item, related_key).all():                      related.save() +        for func, context, value in self._item_post_processing: +            getattr(item, func)(context, value)          return item -    def initialize(self, table, output='silent', choose_default=False): +    def initialize(self, table, output='silent', choose_default=False, +                   user=None):          """          copy vals in columns and initialize formaters          * output: -         - 'silent': no associations -         - 'cli': output by command line interface and stocked in the database -         - 'db': output on the database with no interactive association +         - silent: no associations +         - cli: output by command line interface and stocked in the database +         - db: output on the database with no interactive association             (further exploitation by web interface) +         - user: associated user          """          assert output in ('silent', 'cli', 'db')          vals = [] @@ -859,14 +913,17 @@ class Importer(object):                          db_targets.append(                              self.DB_TARGETS["{}-{}".format(                                  idx + 1, field_name)]) -                    formater.reinit_db_target(db_targets) +                    formater.reinit_db_target(db_targets, user=user)                  formater.init(vals[idx], output, choose_default=choose_default, -                              import_instance=self.import_instance) +                              import_instance=self.import_instance, +                              user=user) -    def importation(self, table, initialize=True, choose_default=False): +    def importation(self, table, initialize=True, choose_default=False, +                    user=None):          if initialize: -            self.initialize(table, self.output, choose_default=choose_default) +            self.initialize(table, self.output, +                            choose_default=choose_default, user=user)          self._importation(table)      def _associate_db_target_to_formaters(self): @@ -990,6 +1047,7 @@ class Importer(object):              return          self._throughs = []  # list of (formater, value)          self._post_processing = []  # list of (formater, value) +        self._item_post_processing = []          data = {}          # keep in database the raw line for testing purpose @@ -1026,8 +1084,7 @@ class Importer(object):          if self.test:              return          # manage unicity of items (mainly for updates) -        if 'history_modifier' in \ -                self.OBJECT_CLS._meta.get_all_field_names(): +        if 'history_modifier' in get_all_field_names(self.OBJECT_CLS):              data['history_modifier'] = self.history_modifier          obj, created = self.get_object(self.OBJECT_CLS, data) @@ -1096,7 +1153,7 @@ class Importer(object):          for formater, val in self._post_processing:              formater.post_process(obj, data, val, owner=self.history_modifier) -        obj = self.post_processing(obj, data) +        self.post_processing(obj, data)      def _row_processing(self, c_row, idx_col, idx_line, val, data):          if idx_col >= len(self.line_format): @@ -1239,162 +1296,241 @@ class Importer(object):                          concat_str=concat_str[idx])          c_row.append(u" ; ".join([v for v in c_values])) +    def _get_field_m2m(self, attribute, data, c_path, new_created, +                       field_object): +        """ +        Manage and m2m field from raw data + +        :param attribute: attribute name +        :param data: current data dictionary +        :param c_path: attribute path from the main model point of view +        :param new_created: dict of forced newly created items to prevent +        multiple creation +        :param field_object: django field object for this attribute +        :return: None +        """ +        m2ms = [] + +        many_values = data.pop(attribute) +        if hasattr(field_object, 'rel'): +            model = field_object.rel.to +        elif hasattr(field_object, 'to'): +            model = field_object.to +        elif hasattr(field_object, 'model'): +            model = field_object.model +        if type(many_values) not in (list, tuple): +            many_values = [many_values] + +        for val in many_values: +            if val.__class__ == model: +                # the value is a model instance: it is OK! +                m2ms.append((attribute, val)) +                continue +            if type(val) != dict: +                # value is not a dict, we don't know what to do with it... +                continue +            vals = [] + +            # contruct many dict for each values +            default_dict = {} + +            # # get default values +            p = [attribute] +            if c_path: +                p = list(c_path) + p +            p = tuple(p) +            if p in self._defaults: +                for k in self._defaults[p]: +                    default_dict[k] = self._defaults[p][k] + +            # # init with simple values that will be duplicated +            for key in val.keys(): +                if type(val[key]) not in (list, tuple): +                    default_dict[key] = val[key] +            vals.append(default_dict.copy()) + +            # # manage multiple values +            for key in val.keys(): +                if type(val[key]) in (list, tuple): +                    for idx, v in enumerate(val[key]): +                        if len(vals) <= idx: +                            vals.append(default_dict.copy()) +                        vals[idx][key] = v + +            # check that m2m are not empty +            notempty = False +            for dct in vals: +                for k in dct: +                    if dct[k] not in ("", None): +                        notempty = True +                        break +            if not notempty: +                continue + +            field_names = get_all_field_names(model) +            for v in vals: +                if 'history_modifier' in field_names: +                    if 'defaults' not in v: +                        v['defaults'] = {} +                    v['defaults']['history_modifier'] = \ +                        self.history_modifier +                m2m_m2ms = [] +                c_c_path = c_path[:] +                for k in v.keys(): +                    if k not in field_names: +                        continue +                    self.get_field(model, k, v, m2m_m2ms, c_c_path, +                                   new_created) +                if '__force_new' in v: +                    created = v.pop('__force_new') +                    key = u";".join([u"{}-{}".format(k, v[k]) +                                     for k in sorted(v.keys())]) +                    # only one forced creation +                    if attribute in new_created \ +                            and key in new_created[attribute]: +                        continue +                    if attribute not in new_created: +                        new_created[attribute] = [] +                    new_created[attribute].append(key) +                    has_values = bool([1 for k in v if v[k]]) +                    if has_values: +                        if self.MODEL_CREATION_LIMIT and \ +                                        model not in self.MODEL_CREATION_LIMIT: +                            raise self._get_improperly_conf_error(model) +                        v = model.objects.create(**v) +                    else: +                        continue +                else: +                    v['defaults'] = v.get('defaults', {}) +                    extra_fields = {} +                    # "File" type is a temp object and can be different +                    # for the same filename - it must be treated +                    # separately +                    for field in model._meta.fields: +                        k = field.name +                        # attr_class is a FileField attribute +                        if hasattr(field, 'attr_class') and k in v: +                            extra_fields[k] = v.pop(k) +                    created = False +                    if not self.MODEL_CREATION_LIMIT or \ +                                    model in self.MODEL_CREATION_LIMIT: +                        v, created = model.objects.get_or_create( +                            **v) +                    else: +                        get_v = v.copy() +                        if 'defaults' in get_v: +                            get_v.pop('defaults') +                        try: +                            v = model.objects.get(**get_v) +                        except model.DoesNotExist: +                            raise self._get_does_not_exist_in_db_error( +                                model, get_v) +                    changed = False +                    for k in extra_fields.keys(): +                        if extra_fields[k]: +                            changed = True +                            setattr(v, k, extra_fields[k]) +                    if changed: +                        v.save() +                for att, objs in m2m_m2ms: +                    if type(objs) not in (list, tuple): +                        objs = [objs] +                    for obj in objs: +                        getattr(v, att).add(obj) +                if self.import_instance \ +                        and hasattr(v, 'imports') and created: +                    v.imports.add(self.import_instance) +                m2ms.append((attribute, v)) +        return m2ms + +    def _set_importer_trigger(self, cls, attribute, data): +        """ +        An importer trigger is used. Stock it for later execution and remove +        it from current data dict. + +        :param cls: current model +        :param attribute: attribute name +        :param data: current data dictionary +        :return: None +        """ +        func = getattr(cls, attribute) +        if func.importer_trigger == 'pre': +            pass  # TODO +        elif func.importer_trigger == 'post': +            self._item_post_processing.append([attribute, data, +                                               data[attribute]]) +        else: +            logger.warning("Unknow importer_trigger '{}' for '{}'".format( +                func.importer_trigger, attribute +            )) +        data.pop(attribute) +      def get_field(self, cls, attribute, data, m2ms, c_path, new_created): +        """ +        Get field from raw data + +        :param cls: current model +        :param attribute: attribute name +        :param data: current data dictionary +        :param m2ms: many to many list of tuple: (m2m key, m2m value) +        :param c_path: attribute path from the main model point of view +        :param new_created: dict of forced newly created items to prevent +        multiple creation +        :return: None +        """ +        if hasattr(cls, attribute) and \ +                getattr(getattr(cls, attribute), 'importer_trigger', None): +            # importer trigger +            self._set_importer_trigger(cls, attribute, data) +            return          try: -            field_object, model, direct, m2m = \ -                cls._meta.get_field_by_name(attribute) +            field_object = cls._meta.get_field(attribute)          except FieldDoesNotExist:              raise ImporterError(unicode(                  _(u"Importer configuration error: field \"{}\" does not exist "                    u"for {}.")).format(attribute, cls._meta.verbose_name)) -        if m2m: -            many_values = data.pop(attribute) -            if hasattr(field_object, 'rel'): -                model = field_object.rel.to -            elif hasattr(field_object, 'to'): -                model = field_object.to -            elif hasattr(field_object, 'model'): -                model = field_object.model -            if type(many_values) not in (list, tuple): -                many_values = [many_values] -            for val in many_values: -                if val.__class__ == model: -                    m2ms.append((attribute, val)) -                elif val.__class__ != model and type(val) == dict: -                    vals = [] - -                    # contruct many dict for each values -                    default_dict = {} - -                    # # get default values -                    p = [attribute] -                    if c_path: -                        p = list(c_path) + p -                    p = tuple(p) -                    if p in self._defaults: -                        for k in self._defaults[p]: -                            default_dict[k] = self._defaults[p][k] -                    # # init with simple values that will be duplicated -                    for key in val.keys(): -                        if type(val[key]) not in (list, tuple): -                            default_dict[key] = val[key] -                    vals.append(default_dict.copy()) -                    # # manage multiple values -                    for key in val.keys(): -                        if type(val[key]) in (list, tuple): -                            for idx, v in enumerate(val[key]): -                                if len(vals) <= idx: -                                    vals.append(default_dict.copy()) -                                vals[idx][key] = v - -                    # check that m2m are not empty -                    notempty = False -                    for dct in vals: -                        for k in dct: -                            if dct[k] not in ("", None): -                                notempty = True -                                break -                    if not notempty: -                        continue - -                    field_names = model._meta.get_all_field_names() -                    for v in vals: -                        if 'history_modifier' in field_names: -                            if 'defaults' not in v: -                                v['defaults'] = {} -                            v['defaults']['history_modifier'] = \ -                                self.history_modifier -                        m2m_m2ms = [] -                        c_c_path = c_path[:] -                        for k in v.keys(): -                            if k not in field_names: -                                continue -                            self.get_field(model, k, v, m2m_m2ms, c_c_path, -                                           new_created) -                        if '__force_new' in v: -                            created = v.pop('__force_new') -                            key = u";".join([u"{}-{}".format(k, v[k]) -                                             for k in sorted(v.keys())]) -                            # only one forced creation -                            if attribute in new_created \ -                                    and key in new_created[attribute]: -                                continue -                            if attribute not in new_created: -                                new_created[attribute] = [] -                            new_created[attribute].append(key) -                            has_values = bool([1 for k in v if v[k]]) -                            if has_values: -                                if self.MODEL_CREATION_LIMIT and \ -                                        model not in self.MODEL_CREATION_LIMIT: -                                    raise self._get_improperly_conf_error(model) -                                v = model.objects.create(**v) -                            else: -                                continue -                        else: -                            v['defaults'] = v.get('defaults', {}) -                            extra_fields = {} -                            # "File" type is a temp object and can be different -                            # for the same filename - it must be treated -                            # separately -                            for field in model._meta.fields: -                                k = field.name -                                # attr_class is a FileField attribute -                                if hasattr(field, 'attr_class') and k in v: -                                    extra_fields[k] = v.pop(k) -                            if not self.MODEL_CREATION_LIMIT or \ -                                    model in self.MODEL_CREATION_LIMIT: -                                v, created = model.objects.get_or_create( -                                    **v) -                            else: -                                get_v = v.copy() -                                if 'defaults' in get_v: -                                    get_v.pop('defaults') -                                try: -                                    v = model.objects.get(**get_v) -                                except model.DoesNotExist: -                                    raise self._get_does_not_exist_in_db_error( -                                        model, get_v) -                            changed = False -                            for k in extra_fields.keys(): -                                if extra_fields[k]: -                                    changed = True -                                    setattr(v, k, extra_fields[k]) -                            if changed: -                                v.save() -                        for att, objs in m2m_m2ms: -                            if type(objs) not in (list, tuple): -                                objs = [objs] -                            for obj in objs: -                                getattr(v, att).add(obj) -                        if self.import_instance \ -                           and hasattr(v, 'imports') and created: -                            v.imports.add(self.import_instance) -                        m2ms.append((attribute, v)) -        elif hasattr(field_object, 'rel') and field_object.rel: -            if type(data[attribute]) == dict: -                # put history_modifier for every created item -                if 'history_modifier' in \ -                   field_object.rel.to._meta.get_all_field_names(): -                    data[attribute]['history_modifier'] = \ -                        self.history_modifier -                try: -                    c_path.append(attribute) -                    data[attribute], created = self.get_object( -                        field_object.rel.to, data[attribute].copy(), c_path) -                except ImporterError, msg: -                    self.errors.append((self.idx_line, None, msg)) -                    data[attribute] = None -            elif type(data[attribute]) == list: -                data[attribute] = data[attribute][0] +        if field_object.many_to_many: +            m2ms += self._get_field_m2m(attribute, data, c_path, +                                        new_created, field_object) +            return +        if not hasattr(field_object, 'rel') or not field_object.rel: +            return +        if type(data[attribute]) == list: +            # extract the first item from list +            # be careful if the list has more than one item this is arbitrary +            if len(data[attribute]) > 1: +                logger.warning( +                    'Import {}: {} has many when only one is expected. Get ' +                    'the first one but it is not OK!'.format( +                     self.import_instance, attribute)) +            data[attribute] = data[attribute][0] +            return +        if type(data[attribute]) != dict: +            # we treat only dict formated values +            return +        # put history_modifier for every created item +        if 'history_modifier' in get_all_field_names(field_object.rel.to): +            data[attribute]['history_modifier'] = \ +                self.history_modifier +        try: +            c_path.append(attribute) +            data[attribute], created = self.get_object( +                field_object.rel.to, data[attribute].copy(), c_path) +        except ImporterError, msg: +            self.errors.append((self.idx_line, None, msg)) +            data[attribute] = None      def get_object(self, cls, data, path=[]):          m2ms = []          if type(data) != dict: +            # if data is not a dict we don't know what to do              return data, False +          is_empty = not bool(              [k for k in data if k not in ('history_modifier', 'defaults')                  and data[k]])          if is_empty: +            # if no value, no creation              return None, False          c_path = path[:] @@ -1408,9 +1544,8 @@ class Importer(object):                      data.pop(attribute)                      continue                  if not data[attribute]: -                    field_object, model, direct, m2m = \ -                        cls._meta.get_field_by_name(attribute) -                    if m2m: +                    field_object = cls._meta.get_field(attribute) +                    if field_object.many_to_many:                          data.pop(attribute)                      continue                  if attribute != '__force_new': @@ -1436,7 +1571,7 @@ class Importer(object):              if type(create_dict[k]) == dict:                  create_dict.pop(k)              # File doesn't like deepcopy -            if type(create_dict[k]) == File: +            elif type(create_dict[k]) == File:                  create_dict[k] = copy.copy(data[k])          # default values | 
