diff options
| author | Étienne Loks <etienne.loks@iggdrasil.net> | 2021-10-01 17:23:23 +0200 | 
|---|---|---|
| committer | Étienne Loks <etienne.loks@iggdrasil.net> | 2022-07-08 09:58:48 +0200 | 
| commit | 0f39694b279007da6d924515c3376962d80f378e (patch) | |
| tree | c507eadcd0c2629ea7dd18dbc75b5983f6a9039e | |
| parent | 2bdffb95a3a4ad83800ecfb9c5885350ed852f1f (diff) | |
| download | Ishtar-0f39694b279007da6d924515c3376962d80f378e.tar.bz2 Ishtar-0f39694b279007da6d924515c3376962d80f378e.zip | |
Performance on imports: group all post treatment to the end
| -rw-r--r-- | archaeological_warehouse/models.py | 48 | ||||
| -rw-r--r-- | archaeological_warehouse/tests.py | 8 | ||||
| -rw-r--r-- | ishtar_common/data_importer.py | 142 | ||||
| -rw-r--r-- | ishtar_common/models_common.py | 2 | ||||
| -rw-r--r-- | ishtar_common/models_imports.py | 13 | ||||
| -rw-r--r-- | ishtar_common/utils.py | 3 | 
6 files changed, 159 insertions, 57 deletions
| diff --git a/archaeological_warehouse/models.py b/archaeological_warehouse/models.py index de499d2b7..f84a2d75b 100644 --- a/archaeological_warehouse/models.py +++ b/archaeological_warehouse/models.py @@ -420,7 +420,7 @@ class Warehouse(      @post_importer_action      def add_localisations(self, context, value): -        self._add_localisations(context, value) +        return self._add_localisations(context, value)      add_localisations.post_save = True @@ -454,6 +454,7 @@ class Warehouse(          )          parent = None +        items = []          for idx, value in enumerate(values):              if idx >= len(divisions):                  if return_errors: @@ -475,6 +476,10 @@ class Warehouse(              )              if created and import_object:                  parent.imports.add(import_object) +            items.append(parent) +        if return_errors: +            return items, None +        return items      @property      def short_label(self): @@ -937,7 +942,7 @@ class Container(          "empty": SearchAltName(pgettext_lazy("key for text search", "empty"), "finds"),          "parent": SearchAltName(              pgettext_lazy("key for text search", "parent-container"), -            "parent__cached_label__iexact" +            "parent__cached_label__iexact",          ),          "contain_containers": SearchAltName(              pgettext_lazy("key for text search", "contain-containers"), @@ -1236,28 +1241,29 @@ class Container(          doc.container_ref_id = self.pk          doc.skip_history_when_saving = True          doc.save() +        return doc      @post_importer_action      def put_document_by_external_id(self, context, value): -        self.put_document_by_key(value, "external_id") +        return self.put_document_by_key(value, "external_id")      put_document_by_external_id.post_save = True      @post_importer_action      def put_document_by_reference(self, context, value): -        self.put_document_by_key(value, "reference") +        return self.put_document_by_key(value, "reference")      put_document_by_reference.post_save = True      @post_importer_action      def put_document_by_internal_reference(self, context, value): -        self.put_document_by_key(value, "internal_reference") +        return self.put_document_by_key(value, "internal_reference")      put_document_by_internal_reference.post_save = True      @post_importer_action      def put_document_by_complete_identifier(self, context, value): -        self.put_document_by_key(value, "complete_identifier") +        return self.put_document_by_key(value, "complete_identifier")      put_document_by_complete_identifier.post_save = True @@ -1750,9 +1756,7 @@ class Container(              return          q = Container.objects.filter(location=self.location, index__isnull=False)          self.index = ( -            int(q.all().aggregate(Max("index"))["index__max"]) + 1 -            if q.count() -            else 1 +            int(q.all().aggregate(Max("index"))["index__max"]) + 1 if q.count() else 1          )          if not self.cached_division:              self.cached_division = self._generate_cached_division() @@ -1767,6 +1771,11 @@ class Container(              self.location.max_division_number = number              self.location.save() +    def post_delete_to_update(self): +        q = Container.objects.filter(container_tree_child__container_parent=self) +        q.update(cached_division="") +        return ((self.__class__, q.values_list("id", flat=True)),) +      def save(self, *args, **kwargs):          self.pre_save()          super(Container, self).save(*args, **kwargs) @@ -1810,33 +1819,20 @@ class Container(  def container_post_save(sender, **kwargs):      cached_label_and_geo_changed(sender=sender, **kwargs) -    # TODO: to be deleted??? -    """ -    if not kwargs.get('instance'): -        return -    instance = kwargs.get('instance') -    for loca in ContainerLocalisation.objects.filter( -            container=instance).exclude( -            division__warehouse=instance.location).all(): -        q = WarehouseDivisionLink.objects.filter( -            warehouse=instance.location, -            division=loca.division.division -        ) -        if not q.count(): -            continue -        loca.division = q.all()[0] -        loca.save() -    """  def container_pre_delete(sender, **kwargs):      instance = kwargs["instance"] +    if getattr(instance, "_no_pre_delete", False): +        return      q = Container.objects.filter(container_tree_child__container_parent=instance)      q.update(cached_division="")  def container_post_delete(sender, **kwargs):      instance = kwargs["instance"] +    if getattr(instance, "_no_pre_delete", False): +        return      q = Container.objects.filter(cached_division="", location=instance.location)      for c in q.all():          c.save() diff --git a/archaeological_warehouse/tests.py b/archaeological_warehouse/tests.py index 4e6ca9dd9..7cd5af848 100644 --- a/archaeological_warehouse/tests.py +++ b/archaeological_warehouse/tests.py @@ -491,7 +491,9 @@ class WarehouseTest(TestCase):          )          self.assertTrue(error) -        error = self.warehouse._add_localisations(None, base_value, return_errors=True) +        __, error = self.warehouse._add_localisations( +            None, base_value, return_errors=True +        )          self.assertIsNone(error)          parent = None          for idx, reference in enumerate(("A", "42", "allée 3;2")): @@ -513,7 +515,7 @@ class WarehouseTest(TestCase):          self.assertEqual(container_nb + len(self.container_types), new_container_nb)          value = "A;42;allée 4" -        error = self.warehouse._add_localisations(None, value, return_errors=True) +        __, error = self.warehouse._add_localisations(None, value, return_errors=True)          self.assertIsNone(error)          # only create a new container          self.assertEqual(new_container_nb + 1, models.Container.objects.count()) @@ -529,7 +531,7 @@ class WarehouseTest(TestCase):          # test with an empty localisation          value = "A;42;;35" -        error = self.warehouse._add_localisations(None, value, return_errors=True) +        __, error = self.warehouse._add_localisations(None, value, return_errors=True)          self.assertIsNone(error)          q = models.Container.objects.filter(              parent__reference="42", diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 64f478614..19d53b72a 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1034,19 +1034,43 @@ class Importer(object):      def post_processing(self, idx_line, item):          # force django based post-processing for the item          item = item.__class__.objects.get(pk=item.pk) -        item.save() -        if hasattr(item, "RELATED_POST_PROCESS"): -            for related_key in item.RELATED_POST_PROCESS: -                for related in getattr(item, related_key).all(): -                    related.save()          for func, context, value in self._item_post_processing:              context["import_object"] = self.import_instance              try: -                getattr(item, func)(context, value) +                returned = getattr(item, func)(context, value) +                if returned: +                    for rel in returned: +                        self._add_to_post_save(rel.__class__, rel.pk, idx_line)              except ImporterError as msg:                  self.errors.append((idx_line, None, msg))          return item +    def post_import(self): +        related_list = {} +        for cls_pk, idx_line in self.post_save_items.items(): +            cls, pk = cls_pk +            # force django based post-processing for the item +            item = cls.objects.get(pk=pk) +            item.save() +            if hasattr(item, "RELATED_POST_PROCESS"): +                for related_key in item.RELATED_POST_PROCESS: +                    for related in getattr(item, related_key).all(): +                        k = (related.__class__, related.pk) +                        if k not in related_list: +                            related_list[k] = idx_line +            if hasattr(item, "fix"): +                # post save/m2m specific fix +                item.fix() +        for cls, pk in related_list.keys(): +            try: +                item = cls.objects.get(pk=pk) +                item.save() +                if hasattr(item, "fix"): +                    # post save/m2m specific fix +                    item.fix() +            except cls.DoesNotExist: +                pass +      def initialize(self, table, output="silent", choose_default=False, user=None):          """          copy vals in columns and initialize formaters @@ -1207,6 +1231,8 @@ class Importer(object):              )          self.errors = []          self.validity = [] +        # a dict with (cls, item.pk) key and import line number as value -> mostly used as an ordered set +        self.post_save_items = {}          self.number_imported = 0          idx_last_col = 0          # index of the last required column @@ -1253,10 +1279,27 @@ class Importer(object):                  results.append(self._line_processing(idx_line, line))              except ImporterError as msg:                  self.errors.append((idx_line, None, msg)) +        self.post_import()          for item in self.to_be_close:              item.close()          return results +    def _add_to_post_save(self, cls, pk, idx_line): +        post_save_k = (cls, pk) +        c_idx_line = idx_line +        if post_save_k in self.post_save_items: +            c_idx_line = self.post_save_items.pop( +                post_save_k +            )  # change the order +        self.post_save_items[post_save_k] = c_idx_line + +    def _create_item(self, cls, dct, idx_line): +        obj = cls(**dct) +        obj._no_post_save = True  # delayed at the end of the import +        obj.save() +        self._add_to_post_save(cls, obj.pk, idx_line) +        return obj +      def _line_processing(self, idx_line, line):          for item in self.to_be_close:              item.close() @@ -1322,7 +1365,7 @@ class Importer(object):          self.new_objects, self.updated_objects = [], []          self.ambiguous_objects, self.not_find_objects = [], [] -        obj, created = self.get_object(self.OBJECT_CLS, data) +        obj, created = self.get_object(self.OBJECT_CLS, data, idx_line=idx_line)          if self.simulate:              return data          # print(data) @@ -1342,7 +1385,9 @@ class Importer(object):          if not created and "defaults" in data:              for k in data["defaults"]:                  setattr(obj, k, data["defaults"][k]) +            obj._no_post_save = True              obj.save() +            self._add_to_post_save(obj.__class__, obj.pk, idx_line)          n = datetime.datetime.now()          logger.debug("* %s - Item saved" % (str(n - n2)))          n2 = n @@ -1372,13 +1417,29 @@ class Importer(object):                  ):                      raise self._get_improperly_conf_error(through_cls)                  created = data.pop("__force_new") -                t_obj = through_cls.objects.create(**data) +                new_data = data.copy() +                if "defaults" in data: +                    default = new_data.pop("defaults") +                    for k in default: +                        if k not in new_data: +                            new_data[k] = default[k] +                t_obj = self._create_item(through_cls, new_data, idx_line)              else:                  if (                      not self.MODEL_CREATION_LIMIT                      or through_cls in self.MODEL_CREATION_LIMIT                  ): -                    t_obj, created = through_cls.objects.get_or_create(**data) +                    new_data = data.copy() +                    if "defaults" in data: +                        default = new_data.pop("defaults") +                    q = through_cls.objects.filter(**new_data) +                    if q.count(): +                        t_obj = through_cls.objects.get(**new_data) +                    else: +                        for k in default: +                            if k not in new_data: +                                new_data[k] = default[k] +                        t_obj = self._create_item(through_cls, new_data, idx_line)                  else:                      get_data = data.copy()                      if "defaults" in get_data: @@ -1393,7 +1454,9 @@ class Importer(object):                  t_obj = t_obj.__class__.objects.get(pk=t_obj.pk)                  for k in data["defaults"]:                      setattr(t_obj, k, data["defaults"][k]) +                t_obj._no_post_save = True                  t_obj.save() +                self._add_to_post_save(t_obj.__class__, t_obj.pk, idx_line)              if self.import_instance and hasattr(t_obj, "imports") and created:                  t_obj.imports.add(self.import_instance)          if not obj: @@ -1575,7 +1638,9 @@ class Importer(object):                      )          c_row.append(" ; ".join([v for v in c_values])) -    def _get_field_m2m(self, attribute, data, c_path, new_created, field_object): +    def _get_field_m2m( +        self, attribute, data, c_path, new_created, field_object, idx_line=None +    ):          """          Manage and m2m field from raw data @@ -1659,7 +1724,9 @@ class Importer(object):                  for k in list(v.keys()):                      if k not in field_names:                          continue -                    self.get_field(model, k, v, m2m_m2ms, c_c_path, new_created) +                    self.get_field( +                        model, k, v, m2m_m2ms, c_c_path, new_created, idx_line=idx_line +                    )                  if "__force_new" in v:                      created = v.pop("__force_new")                      key = ";".join(["{}-{}".format(k, v[k]) for k in sorted(v.keys())]) @@ -1760,7 +1827,7 @@ class Importer(object):              )          data.pop(attribute) -    def get_field(self, cls, attribute, data, m2ms, c_path, new_created): +    def get_field(self, cls, attribute, data, m2ms, c_path, new_created, idx_line=None):          """          Get field from raw data @@ -1799,7 +1866,12 @@ class Importer(object):          if field_object.many_to_many:              try:                  m2ms += self._get_field_m2m( -                    attribute, data, c_path, new_created, field_object +                    attribute, +                    data, +                    c_path, +                    new_created, +                    field_object, +                    idx_line=idx_line,                  )              except Exception as e:                  self.errors.append((self.idx_line, None, str(e))) @@ -1827,13 +1899,13 @@ class Importer(object):          try:              c_path.append(attribute)              data[attribute], created = self.get_object( -                field_object.rel.to, data[attribute].copy(), c_path +                field_object.rel.to, data[attribute].copy(), c_path, idx_line=idx_line              )          except ImporterError as msg:              self.errors.append((self.idx_line, None, msg))              data[attribute] = None -    def get_object(self, cls, data, path=None): +    def get_object(self, cls, data, path=None, idx_line=None):          if not path:              path = []          m2ms = [] @@ -1871,7 +1943,15 @@ class Importer(object):                          data.pop(attribute)                      continue                  if attribute != "__force_new": -                    self.get_field(cls, attribute, data, m2ms, c_c_path, new_created) +                    self.get_field( +                        cls, +                        attribute, +                        data, +                        m2ms, +                        c_c_path, +                        new_created, +                        idx_line=idx_line, +                    )          except (ValueError, IntegrityError, FieldDoesNotExist) as e:              try:                  message = str(e) @@ -1951,7 +2031,7 @@ class Importer(object):                      ):                          raise self._get_improperly_conf_error(cls)                      if not self.simulate: -                        obj = cls.objects.create(**new_dct) +                        self._create_item(cls, new_dct, idx_line)                      else:                          self.new_objects.append((path, cls, new_dct))                  else: @@ -2008,14 +2088,26 @@ class Importer(object):                                  not self.MODEL_CREATION_LIMIT                                  or cls in self.MODEL_CREATION_LIMIT                              ): -                                dct["defaults"] = defaults.copy() -                                obj, created = cls.objects.get_or_create(**dct) +                                q = cls.objects.filter(**dct) +                                if q.count(): +                                    obj = cls.objects.get(**dct) +                                else: +                                    created = True +                                    new_dct = dct.copy() +                                    for k in defaults: +                                        if k not in dct: +                                            new_dct[k] = defaults[k] +                                    obj = self._create_item(cls, new_dct, idx_line)                              else:                                  try:                                      obj = cls.objects.get(**dct) -                                    dct["defaults"] = defaults.copy() +                                    obj._no_post_save = ( +                                        True  # delayed at the end of the import +                                    ) +                                    self._add_to_post_save(cls, obj.pk, idx_line)                                  except cls.DoesNotExist:                                      raise self._get_does_not_exist_in_db_error(cls, dct) +                            dct["defaults"] = defaults.copy()                      if not created and not path and self.UNICITY_KEYS:                          updated_dct = {} @@ -2112,6 +2204,8 @@ class Importer(object):                              getattr(obj, attr).add(v)                          # force post save script                          v = v.__class__.objects.get(pk=v.pk) +                        self._add_to_post_save(v.__class__, v.pk, idx_line) +                        v._no_post_save = True                          try:                              v.save()                          except DatabaseError as e: @@ -2135,13 +2229,6 @@ class Importer(object):                      # defaults are not presented as matching data                      dct.pop("defaults")                      return self.updated_objects[-1][1], False - -            if m2ms: -                # force post save script -                obj.save() -            if hasattr(obj, "fix"): -                # post save/m2m specific fix -                obj.fix()          except IntegrityError as e:              try:                  message = str(e) @@ -2159,6 +2246,7 @@ class Importer(object):                      message                  )              ) +        self._add_to_post_save(obj.__class__, obj.pk, idx_line)          return obj, created      def _format_csv_line(self, values, empty="-"): diff --git a/ishtar_common/models_common.py b/ishtar_common/models_common.py index 08cfabdab..b34445f36 100644 --- a/ishtar_common/models_common.py +++ b/ishtar_common/models_common.py @@ -1058,7 +1058,7 @@ class JsonData(models.Model, CachedGen):          if not current_keys:              return          for keys in current_keys: -            if keys[0] == "__get_dynamic_choices": +            if keys and keys[0] == "__get_dynamic_choices":                  cls._get_dynamic_choices(keys[1], force=True)      @classmethod diff --git a/ishtar_common/models_imports.py b/ishtar_common/models_imports.py index 694b0d17d..c91ce473f 100644 --- a/ishtar_common/models_imports.py +++ b/ishtar_common/models_imports.py @@ -1486,8 +1486,21 @@ def pre_delete_import(sender, **kwargs):      to_delete = []      for accessor, imported in instance.get_all_imported():          to_delete.append(imported) +    post_delete_to_update = {}      for item in to_delete: +        if hasattr(item, "post_delete_to_update"): +            item._no_pre_delete = True +            for klass, values in item.post_delete_to_update(): +                if klass not in post_delete_to_update: +                    post_delete_to_update[klass] = set(values) +                else: +                    post_delete_to_update[klass].update(values)          item.delete() +    for klass in post_delete_to_update: +        for item_id in post_delete_to_update[klass]: +            q = klass.objects.filter(pk=item_id) +            if q.count(): +                q.all()[0].save()  pre_delete.connect(pre_delete_import, sender=Import) diff --git a/ishtar_common/utils.py b/ishtar_common/utils.py index 3d3131570..7173a2c46 100644 --- a/ishtar_common/utils.py +++ b/ishtar_common/utils.py @@ -475,6 +475,9 @@ EXTRA_KWARGS_TRIGGER = [  def cached_label_and_geo_changed(sender, **kwargs): +    instance = kwargs["instance"] +    if getattr(instance, "_no_post_save", False): +        return      cached_label_changed(sender=sender, **kwargs)      post_save_geo(sender=sender, **kwargs) | 
