diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
-rw-r--r-- | ishtar_common/data_importer.py | 142 |
1 files changed, 115 insertions, 27 deletions
diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 64f478614..19d53b72a 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1034,19 +1034,43 @@ class Importer(object): def post_processing(self, idx_line, item): # force django based post-processing for the item item = item.__class__.objects.get(pk=item.pk) - item.save() - if hasattr(item, "RELATED_POST_PROCESS"): - for related_key in item.RELATED_POST_PROCESS: - for related in getattr(item, related_key).all(): - related.save() for func, context, value in self._item_post_processing: context["import_object"] = self.import_instance try: - getattr(item, func)(context, value) + returned = getattr(item, func)(context, value) + if returned: + for rel in returned: + self._add_to_post_save(rel.__class__, rel.pk, idx_line) except ImporterError as msg: self.errors.append((idx_line, None, msg)) return item + def post_import(self): + related_list = {} + for cls_pk, idx_line in self.post_save_items.items(): + cls, pk = cls_pk + # force django based post-processing for the item + item = cls.objects.get(pk=pk) + item.save() + if hasattr(item, "RELATED_POST_PROCESS"): + for related_key in item.RELATED_POST_PROCESS: + for related in getattr(item, related_key).all(): + k = (related.__class__, related.pk) + if k not in related_list: + related_list[k] = idx_line + if hasattr(item, "fix"): + # post save/m2m specific fix + item.fix() + for cls, pk in related_list.keys(): + try: + item = cls.objects.get(pk=pk) + item.save() + if hasattr(item, "fix"): + # post save/m2m specific fix + item.fix() + except cls.DoesNotExist: + pass + def initialize(self, table, output="silent", choose_default=False, user=None): """ copy vals in columns and initialize formaters @@ -1207,6 +1231,8 @@ class Importer(object): ) self.errors = [] self.validity = [] + # a dict with (cls, item.pk) key and import line number as value -> mostly used as an ordered set + self.post_save_items = {} self.number_imported = 0 idx_last_col = 0 # index of the last required column @@ -1253,10 +1279,27 @@ class Importer(object): results.append(self._line_processing(idx_line, line)) except ImporterError as msg: self.errors.append((idx_line, None, msg)) + self.post_import() for item in self.to_be_close: item.close() return results + def _add_to_post_save(self, cls, pk, idx_line): + post_save_k = (cls, pk) + c_idx_line = idx_line + if post_save_k in self.post_save_items: + c_idx_line = self.post_save_items.pop( + post_save_k + ) # change the order + self.post_save_items[post_save_k] = c_idx_line + + def _create_item(self, cls, dct, idx_line): + obj = cls(**dct) + obj._no_post_save = True # delayed at the end of the import + obj.save() + self._add_to_post_save(cls, obj.pk, idx_line) + return obj + def _line_processing(self, idx_line, line): for item in self.to_be_close: item.close() @@ -1322,7 +1365,7 @@ class Importer(object): self.new_objects, self.updated_objects = [], [] self.ambiguous_objects, self.not_find_objects = [], [] - obj, created = self.get_object(self.OBJECT_CLS, data) + obj, created = self.get_object(self.OBJECT_CLS, data, idx_line=idx_line) if self.simulate: return data # print(data) @@ -1342,7 +1385,9 @@ class Importer(object): if not created and "defaults" in data: for k in data["defaults"]: setattr(obj, k, data["defaults"][k]) + obj._no_post_save = True obj.save() + self._add_to_post_save(obj.__class__, obj.pk, idx_line) n = datetime.datetime.now() logger.debug("* %s - Item saved" % (str(n - n2))) n2 = n @@ -1372,13 +1417,29 @@ class Importer(object): ): raise self._get_improperly_conf_error(through_cls) created = data.pop("__force_new") - t_obj = through_cls.objects.create(**data) + new_data = data.copy() + if "defaults" in data: + default = new_data.pop("defaults") + for k in default: + if k not in new_data: + new_data[k] = default[k] + t_obj = self._create_item(through_cls, new_data, idx_line) else: if ( not self.MODEL_CREATION_LIMIT or through_cls in self.MODEL_CREATION_LIMIT ): - t_obj, created = through_cls.objects.get_or_create(**data) + new_data = data.copy() + if "defaults" in data: + default = new_data.pop("defaults") + q = through_cls.objects.filter(**new_data) + if q.count(): + t_obj = through_cls.objects.get(**new_data) + else: + for k in default: + if k not in new_data: + new_data[k] = default[k] + t_obj = self._create_item(through_cls, new_data, idx_line) else: get_data = data.copy() if "defaults" in get_data: @@ -1393,7 +1454,9 @@ class Importer(object): t_obj = t_obj.__class__.objects.get(pk=t_obj.pk) for k in data["defaults"]: setattr(t_obj, k, data["defaults"][k]) + t_obj._no_post_save = True t_obj.save() + self._add_to_post_save(t_obj.__class__, t_obj.pk, idx_line) if self.import_instance and hasattr(t_obj, "imports") and created: t_obj.imports.add(self.import_instance) if not obj: @@ -1575,7 +1638,9 @@ class Importer(object): ) c_row.append(" ; ".join([v for v in c_values])) - def _get_field_m2m(self, attribute, data, c_path, new_created, field_object): + def _get_field_m2m( + self, attribute, data, c_path, new_created, field_object, idx_line=None + ): """ Manage and m2m field from raw data @@ -1659,7 +1724,9 @@ class Importer(object): for k in list(v.keys()): if k not in field_names: continue - self.get_field(model, k, v, m2m_m2ms, c_c_path, new_created) + self.get_field( + model, k, v, m2m_m2ms, c_c_path, new_created, idx_line=idx_line + ) if "__force_new" in v: created = v.pop("__force_new") key = ";".join(["{}-{}".format(k, v[k]) for k in sorted(v.keys())]) @@ -1760,7 +1827,7 @@ class Importer(object): ) data.pop(attribute) - def get_field(self, cls, attribute, data, m2ms, c_path, new_created): + def get_field(self, cls, attribute, data, m2ms, c_path, new_created, idx_line=None): """ Get field from raw data @@ -1799,7 +1866,12 @@ class Importer(object): if field_object.many_to_many: try: m2ms += self._get_field_m2m( - attribute, data, c_path, new_created, field_object + attribute, + data, + c_path, + new_created, + field_object, + idx_line=idx_line, ) except Exception as e: self.errors.append((self.idx_line, None, str(e))) @@ -1827,13 +1899,13 @@ class Importer(object): try: c_path.append(attribute) data[attribute], created = self.get_object( - field_object.rel.to, data[attribute].copy(), c_path + field_object.rel.to, data[attribute].copy(), c_path, idx_line=idx_line ) except ImporterError as msg: self.errors.append((self.idx_line, None, msg)) data[attribute] = None - def get_object(self, cls, data, path=None): + def get_object(self, cls, data, path=None, idx_line=None): if not path: path = [] m2ms = [] @@ -1871,7 +1943,15 @@ class Importer(object): data.pop(attribute) continue if attribute != "__force_new": - self.get_field(cls, attribute, data, m2ms, c_c_path, new_created) + self.get_field( + cls, + attribute, + data, + m2ms, + c_c_path, + new_created, + idx_line=idx_line, + ) except (ValueError, IntegrityError, FieldDoesNotExist) as e: try: message = str(e) @@ -1951,7 +2031,7 @@ class Importer(object): ): raise self._get_improperly_conf_error(cls) if not self.simulate: - obj = cls.objects.create(**new_dct) + self._create_item(cls, new_dct, idx_line) else: self.new_objects.append((path, cls, new_dct)) else: @@ -2008,14 +2088,26 @@ class Importer(object): not self.MODEL_CREATION_LIMIT or cls in self.MODEL_CREATION_LIMIT ): - dct["defaults"] = defaults.copy() - obj, created = cls.objects.get_or_create(**dct) + q = cls.objects.filter(**dct) + if q.count(): + obj = cls.objects.get(**dct) + else: + created = True + new_dct = dct.copy() + for k in defaults: + if k not in dct: + new_dct[k] = defaults[k] + obj = self._create_item(cls, new_dct, idx_line) else: try: obj = cls.objects.get(**dct) - dct["defaults"] = defaults.copy() + obj._no_post_save = ( + True # delayed at the end of the import + ) + self._add_to_post_save(cls, obj.pk, idx_line) except cls.DoesNotExist: raise self._get_does_not_exist_in_db_error(cls, dct) + dct["defaults"] = defaults.copy() if not created and not path and self.UNICITY_KEYS: updated_dct = {} @@ -2112,6 +2204,8 @@ class Importer(object): getattr(obj, attr).add(v) # force post save script v = v.__class__.objects.get(pk=v.pk) + self._add_to_post_save(v.__class__, v.pk, idx_line) + v._no_post_save = True try: v.save() except DatabaseError as e: @@ -2135,13 +2229,6 @@ class Importer(object): # defaults are not presented as matching data dct.pop("defaults") return self.updated_objects[-1][1], False - - if m2ms: - # force post save script - obj.save() - if hasattr(obj, "fix"): - # post save/m2m specific fix - obj.fix() except IntegrityError as e: try: message = str(e) @@ -2159,6 +2246,7 @@ class Importer(object): message ) ) + self._add_to_post_save(obj.__class__, obj.pk, idx_line) return obj, created def _format_csv_line(self, values, empty="-"): |