diff options
Diffstat (limited to 'ishtar_common')
| -rw-r--r-- | ishtar_common/data_importer.py | 142 | ||||
| -rw-r--r-- | ishtar_common/models_common.py | 2 | ||||
| -rw-r--r-- | ishtar_common/models_imports.py | 13 | ||||
| -rw-r--r-- | ishtar_common/utils.py | 3 | 
4 files changed, 132 insertions, 28 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 64f478614..19d53b72a 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1034,19 +1034,43 @@ class Importer(object):      def post_processing(self, idx_line, item):          # force django based post-processing for the item          item = item.__class__.objects.get(pk=item.pk) -        item.save() -        if hasattr(item, "RELATED_POST_PROCESS"): -            for related_key in item.RELATED_POST_PROCESS: -                for related in getattr(item, related_key).all(): -                    related.save()          for func, context, value in self._item_post_processing:              context["import_object"] = self.import_instance              try: -                getattr(item, func)(context, value) +                returned = getattr(item, func)(context, value) +                if returned: +                    for rel in returned: +                        self._add_to_post_save(rel.__class__, rel.pk, idx_line)              except ImporterError as msg:                  self.errors.append((idx_line, None, msg))          return item +    def post_import(self): +        related_list = {} +        for cls_pk, idx_line in self.post_save_items.items(): +            cls, pk = cls_pk +            # force django based post-processing for the item +            item = cls.objects.get(pk=pk) +            item.save() +            if hasattr(item, "RELATED_POST_PROCESS"): +                for related_key in item.RELATED_POST_PROCESS: +                    for related in getattr(item, related_key).all(): +                        k = (related.__class__, related.pk) +                        if k not in related_list: +                            related_list[k] = idx_line +            if hasattr(item, "fix"): +                # post save/m2m specific fix +                item.fix() +        for cls, pk in related_list.keys(): +            try: +                item = cls.objects.get(pk=pk) +                item.save() +                if hasattr(item, "fix"): +                    # post save/m2m specific fix +                    item.fix() +            except cls.DoesNotExist: +                pass +      def initialize(self, table, output="silent", choose_default=False, user=None):          """          copy vals in columns and initialize formaters @@ -1207,6 +1231,8 @@ class Importer(object):              )          self.errors = []          self.validity = [] +        # a dict with (cls, item.pk) key and import line number as value -> mostly used as an ordered set +        self.post_save_items = {}          self.number_imported = 0          idx_last_col = 0          # index of the last required column @@ -1253,10 +1279,27 @@ class Importer(object):                  results.append(self._line_processing(idx_line, line))              except ImporterError as msg:                  self.errors.append((idx_line, None, msg)) +        self.post_import()          for item in self.to_be_close:              item.close()          return results +    def _add_to_post_save(self, cls, pk, idx_line): +        post_save_k = (cls, pk) +        c_idx_line = idx_line +        if post_save_k in self.post_save_items: +            c_idx_line = self.post_save_items.pop( +                post_save_k +            )  # change the order +        self.post_save_items[post_save_k] = c_idx_line + +    def _create_item(self, cls, dct, idx_line): +        obj = cls(**dct) +        obj._no_post_save = True  # delayed at the end of the import +        obj.save() +        self._add_to_post_save(cls, obj.pk, idx_line) +        return obj +      def _line_processing(self, idx_line, line):          for item in self.to_be_close:              item.close() @@ -1322,7 +1365,7 @@ class Importer(object):          self.new_objects, self.updated_objects = [], []          self.ambiguous_objects, self.not_find_objects = [], [] -        obj, created = self.get_object(self.OBJECT_CLS, data) +        obj, created = self.get_object(self.OBJECT_CLS, data, idx_line=idx_line)          if self.simulate:              return data          # print(data) @@ -1342,7 +1385,9 @@ class Importer(object):          if not created and "defaults" in data:              for k in data["defaults"]:                  setattr(obj, k, data["defaults"][k]) +            obj._no_post_save = True              obj.save() +            self._add_to_post_save(obj.__class__, obj.pk, idx_line)          n = datetime.datetime.now()          logger.debug("* %s - Item saved" % (str(n - n2)))          n2 = n @@ -1372,13 +1417,29 @@ class Importer(object):                  ):                      raise self._get_improperly_conf_error(through_cls)                  created = data.pop("__force_new") -                t_obj = through_cls.objects.create(**data) +                new_data = data.copy() +                if "defaults" in data: +                    default = new_data.pop("defaults") +                    for k in default: +                        if k not in new_data: +                            new_data[k] = default[k] +                t_obj = self._create_item(through_cls, new_data, idx_line)              else:                  if (                      not self.MODEL_CREATION_LIMIT                      or through_cls in self.MODEL_CREATION_LIMIT                  ): -                    t_obj, created = through_cls.objects.get_or_create(**data) +                    new_data = data.copy() +                    if "defaults" in data: +                        default = new_data.pop("defaults") +                    q = through_cls.objects.filter(**new_data) +                    if q.count(): +                        t_obj = through_cls.objects.get(**new_data) +                    else: +                        for k in default: +                            if k not in new_data: +                                new_data[k] = default[k] +                        t_obj = self._create_item(through_cls, new_data, idx_line)                  else:                      get_data = data.copy()                      if "defaults" in get_data: @@ -1393,7 +1454,9 @@ class Importer(object):                  t_obj = t_obj.__class__.objects.get(pk=t_obj.pk)                  for k in data["defaults"]:                      setattr(t_obj, k, data["defaults"][k]) +                t_obj._no_post_save = True                  t_obj.save() +                self._add_to_post_save(t_obj.__class__, t_obj.pk, idx_line)              if self.import_instance and hasattr(t_obj, "imports") and created:                  t_obj.imports.add(self.import_instance)          if not obj: @@ -1575,7 +1638,9 @@ class Importer(object):                      )          c_row.append(" ; ".join([v for v in c_values])) -    def _get_field_m2m(self, attribute, data, c_path, new_created, field_object): +    def _get_field_m2m( +        self, attribute, data, c_path, new_created, field_object, idx_line=None +    ):          """          Manage and m2m field from raw data @@ -1659,7 +1724,9 @@ class Importer(object):                  for k in list(v.keys()):                      if k not in field_names:                          continue -                    self.get_field(model, k, v, m2m_m2ms, c_c_path, new_created) +                    self.get_field( +                        model, k, v, m2m_m2ms, c_c_path, new_created, idx_line=idx_line +                    )                  if "__force_new" in v:                      created = v.pop("__force_new")                      key = ";".join(["{}-{}".format(k, v[k]) for k in sorted(v.keys())]) @@ -1760,7 +1827,7 @@ class Importer(object):              )          data.pop(attribute) -    def get_field(self, cls, attribute, data, m2ms, c_path, new_created): +    def get_field(self, cls, attribute, data, m2ms, c_path, new_created, idx_line=None):          """          Get field from raw data @@ -1799,7 +1866,12 @@ class Importer(object):          if field_object.many_to_many:              try:                  m2ms += self._get_field_m2m( -                    attribute, data, c_path, new_created, field_object +                    attribute, +                    data, +                    c_path, +                    new_created, +                    field_object, +                    idx_line=idx_line,                  )              except Exception as e:                  self.errors.append((self.idx_line, None, str(e))) @@ -1827,13 +1899,13 @@ class Importer(object):          try:              c_path.append(attribute)              data[attribute], created = self.get_object( -                field_object.rel.to, data[attribute].copy(), c_path +                field_object.rel.to, data[attribute].copy(), c_path, idx_line=idx_line              )          except ImporterError as msg:              self.errors.append((self.idx_line, None, msg))              data[attribute] = None -    def get_object(self, cls, data, path=None): +    def get_object(self, cls, data, path=None, idx_line=None):          if not path:              path = []          m2ms = [] @@ -1871,7 +1943,15 @@ class Importer(object):                          data.pop(attribute)                      continue                  if attribute != "__force_new": -                    self.get_field(cls, attribute, data, m2ms, c_c_path, new_created) +                    self.get_field( +                        cls, +                        attribute, +                        data, +                        m2ms, +                        c_c_path, +                        new_created, +                        idx_line=idx_line, +                    )          except (ValueError, IntegrityError, FieldDoesNotExist) as e:              try:                  message = str(e) @@ -1951,7 +2031,7 @@ class Importer(object):                      ):                          raise self._get_improperly_conf_error(cls)                      if not self.simulate: -                        obj = cls.objects.create(**new_dct) +                        self._create_item(cls, new_dct, idx_line)                      else:                          self.new_objects.append((path, cls, new_dct))                  else: @@ -2008,14 +2088,26 @@ class Importer(object):                                  not self.MODEL_CREATION_LIMIT                                  or cls in self.MODEL_CREATION_LIMIT                              ): -                                dct["defaults"] = defaults.copy() -                                obj, created = cls.objects.get_or_create(**dct) +                                q = cls.objects.filter(**dct) +                                if q.count(): +                                    obj = cls.objects.get(**dct) +                                else: +                                    created = True +                                    new_dct = dct.copy() +                                    for k in defaults: +                                        if k not in dct: +                                            new_dct[k] = defaults[k] +                                    obj = self._create_item(cls, new_dct, idx_line)                              else:                                  try:                                      obj = cls.objects.get(**dct) -                                    dct["defaults"] = defaults.copy() +                                    obj._no_post_save = ( +                                        True  # delayed at the end of the import +                                    ) +                                    self._add_to_post_save(cls, obj.pk, idx_line)                                  except cls.DoesNotExist:                                      raise self._get_does_not_exist_in_db_error(cls, dct) +                            dct["defaults"] = defaults.copy()                      if not created and not path and self.UNICITY_KEYS:                          updated_dct = {} @@ -2112,6 +2204,8 @@ class Importer(object):                              getattr(obj, attr).add(v)                          # force post save script                          v = v.__class__.objects.get(pk=v.pk) +                        self._add_to_post_save(v.__class__, v.pk, idx_line) +                        v._no_post_save = True                          try:                              v.save()                          except DatabaseError as e: @@ -2135,13 +2229,6 @@ class Importer(object):                      # defaults are not presented as matching data                      dct.pop("defaults")                      return self.updated_objects[-1][1], False - -            if m2ms: -                # force post save script -                obj.save() -            if hasattr(obj, "fix"): -                # post save/m2m specific fix -                obj.fix()          except IntegrityError as e:              try:                  message = str(e) @@ -2159,6 +2246,7 @@ class Importer(object):                      message                  )              ) +        self._add_to_post_save(obj.__class__, obj.pk, idx_line)          return obj, created      def _format_csv_line(self, values, empty="-"): diff --git a/ishtar_common/models_common.py b/ishtar_common/models_common.py index 08cfabdab..b34445f36 100644 --- a/ishtar_common/models_common.py +++ b/ishtar_common/models_common.py @@ -1058,7 +1058,7 @@ class JsonData(models.Model, CachedGen):          if not current_keys:              return          for keys in current_keys: -            if keys[0] == "__get_dynamic_choices": +            if keys and keys[0] == "__get_dynamic_choices":                  cls._get_dynamic_choices(keys[1], force=True)      @classmethod diff --git a/ishtar_common/models_imports.py b/ishtar_common/models_imports.py index 694b0d17d..c91ce473f 100644 --- a/ishtar_common/models_imports.py +++ b/ishtar_common/models_imports.py @@ -1486,8 +1486,21 @@ def pre_delete_import(sender, **kwargs):      to_delete = []      for accessor, imported in instance.get_all_imported():          to_delete.append(imported) +    post_delete_to_update = {}      for item in to_delete: +        if hasattr(item, "post_delete_to_update"): +            item._no_pre_delete = True +            for klass, values in item.post_delete_to_update(): +                if klass not in post_delete_to_update: +                    post_delete_to_update[klass] = set(values) +                else: +                    post_delete_to_update[klass].update(values)          item.delete() +    for klass in post_delete_to_update: +        for item_id in post_delete_to_update[klass]: +            q = klass.objects.filter(pk=item_id) +            if q.count(): +                q.all()[0].save()  pre_delete.connect(pre_delete_import, sender=Import) diff --git a/ishtar_common/utils.py b/ishtar_common/utils.py index 3d3131570..7173a2c46 100644 --- a/ishtar_common/utils.py +++ b/ishtar_common/utils.py @@ -475,6 +475,9 @@ EXTRA_KWARGS_TRIGGER = [  def cached_label_and_geo_changed(sender, **kwargs): +    instance = kwargs["instance"] +    if getattr(instance, "_no_post_save", False): +        return      cached_label_changed(sender=sender, **kwargs)      post_save_geo(sender=sender, **kwargs) | 
