summaryrefslogtreecommitdiff
path: root/ishtar_common/data_importer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ishtar_common/data_importer.py')
-rw-r--r--ishtar_common/data_importer.py142
1 file changed, 115 insertions, 27 deletions
diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py
index ea3054110..666e93f7f 100644
--- a/ishtar_common/data_importer.py
+++ b/ishtar_common/data_importer.py
@@ -1044,19 +1044,43 @@ class Importer(object):
def post_processing(self, idx_line, item):
# force django based post-processing for the item
item = item.__class__.objects.get(pk=item.pk)
- item.save()
- if hasattr(item, "RELATED_POST_PROCESS"):
- for related_key in item.RELATED_POST_PROCESS:
- for related in getattr(item, related_key).all():
- related.save()
for func, context, value in self._item_post_processing:
context["import_object"] = self.import_instance
try:
- getattr(item, func)(context, value)
+ returned = getattr(item, func)(context, value)
+ if returned:
+ for rel in returned:
+ self._add_to_post_save(rel.__class__, rel.pk, idx_line)
except ImporterError as msg:
self.errors.append((idx_line, None, msg))
return item
+ def post_import(self):
+ related_list = {}
+ for cls_pk, idx_line in self.post_save_items.items():
+ cls, pk = cls_pk
+ # force django based post-processing for the item
+ item = cls.objects.get(pk=pk)
+ item.save()
+ if hasattr(item, "RELATED_POST_PROCESS"):
+ for related_key in item.RELATED_POST_PROCESS:
+ for related in getattr(item, related_key).all():
+ k = (related.__class__, related.pk)
+ if k not in related_list:
+ related_list[k] = idx_line
+ if hasattr(item, "fix"):
+ # post save/m2m specific fix
+ item.fix()
+ for cls, pk in related_list.keys():
+ try:
+ item = cls.objects.get(pk=pk)
+ item.save()
+ if hasattr(item, "fix"):
+ # post save/m2m specific fix
+ item.fix()
+ except cls.DoesNotExist:
+ pass
+
def initialize(self, table, output="silent", choose_default=False, user=None):
"""
copy vals in columns and initialize formaters
@@ -1217,6 +1241,8 @@ class Importer(object):
)
self.errors = []
self.validity = []
+ # a dict keyed by (cls, item.pk) tuples whose value is the import line number -> mostly used as an insertion-ordered set
+ self.post_save_items = {}
self.number_imported = 0
idx_last_col = 0
# index of the last required column
@@ -1263,10 +1289,27 @@ class Importer(object):
results.append(self._line_processing(idx_line, line))
except ImporterError as msg:
self.errors.append((idx_line, None, msg))
+ self.post_import()
for item in self.to_be_close:
item.close()
return results
+ def _add_to_post_save(self, cls, pk, idx_line):
+ post_save_k = (cls, pk)
+ c_idx_line = idx_line
+ if post_save_k in self.post_save_items:
+ c_idx_line = self.post_save_items.pop(
+ post_save_k
+ )  # pop and re-insert so the key moves to the end of the insertion order
+ self.post_save_items[post_save_k] = c_idx_line
+
+ def _create_item(self, cls, dct, idx_line):
+ obj = cls(**dct)
+ obj._no_post_save = True  # delayed until the end of the import
+ obj.save()
+ self._add_to_post_save(cls, obj.pk, idx_line)
+ return obj
+
def _line_processing(self, idx_line, line):
for item in self.to_be_close:
item.close()
@@ -1332,7 +1375,7 @@ class Importer(object):
self.new_objects, self.updated_objects = [], []
self.ambiguous_objects, self.not_find_objects = [], []
- obj, created = self.get_object(self.OBJECT_CLS, data)
+ obj, created = self.get_object(self.OBJECT_CLS, data, idx_line=idx_line)
if self.simulate:
return data
# print(data)
@@ -1352,7 +1395,9 @@ class Importer(object):
if not created and "defaults" in data:
for k in data["defaults"]:
setattr(obj, k, data["defaults"][k])
+ obj._no_post_save = True
obj.save()
+ self._add_to_post_save(obj.__class__, obj.pk, idx_line)
n = datetime.datetime.now()
logger.debug("* %s - Item saved" % (str(n - n2)))
n2 = n
@@ -1382,13 +1427,29 @@ class Importer(object):
):
raise self._get_improperly_conf_error(through_cls)
created = data.pop("__force_new")
- t_obj = through_cls.objects.create(**data)
+ new_data = data.copy()
+ if "defaults" in data:
+ default = new_data.pop("defaults")
+ for k in default:
+ if k not in new_data:
+ new_data[k] = default[k]
+ t_obj = self._create_item(through_cls, new_data, idx_line)
else:
if (
not self.MODEL_CREATION_LIMIT
or through_cls in self.MODEL_CREATION_LIMIT
):
- t_obj, created = through_cls.objects.get_or_create(**data)
+ new_data = data.copy()
+ if "defaults" in data:
+ default = new_data.pop("defaults")
+ q = through_cls.objects.filter(**new_data)
+ if q.count():
+ t_obj = through_cls.objects.get(**new_data)
+ else:
+ for k in default:
+ if k not in new_data:
+ new_data[k] = default[k]
+ t_obj = self._create_item(through_cls, new_data, idx_line)
else:
get_data = data.copy()
if "defaults" in get_data:
@@ -1403,7 +1464,9 @@ class Importer(object):
t_obj = t_obj.__class__.objects.get(pk=t_obj.pk)
for k in data["defaults"]:
setattr(t_obj, k, data["defaults"][k])
+ t_obj._no_post_save = True
t_obj.save()
+ self._add_to_post_save(t_obj.__class__, t_obj.pk, idx_line)
if self.import_instance and hasattr(t_obj, "imports") and created:
t_obj.imports.add(self.import_instance)
if not obj:
@@ -1585,7 +1648,9 @@ class Importer(object):
)
c_row.append(" ; ".join([v for v in c_values]))
- def _get_field_m2m(self, attribute, data, c_path, new_created, field_object):
+ def _get_field_m2m(
+ self, attribute, data, c_path, new_created, field_object, idx_line=None
+ ):
"""
Manage and m2m field from raw data
@@ -1669,7 +1734,9 @@ class Importer(object):
for k in list(v.keys()):
if k not in field_names:
continue
- self.get_field(model, k, v, m2m_m2ms, c_c_path, new_created)
+ self.get_field(
+ model, k, v, m2m_m2ms, c_c_path, new_created, idx_line=idx_line
+ )
if "__force_new" in v:
created = v.pop("__force_new")
key = ";".join(["{}-{}".format(k, v[k]) for k in sorted(v.keys())])
@@ -1769,7 +1836,7 @@ class Importer(object):
)
data.pop(attribute)
- def get_field(self, cls, attribute, data, m2ms, c_path, new_created):
+ def get_field(self, cls, attribute, data, m2ms, c_path, new_created, idx_line=None):
"""
Get field from raw data
@@ -1808,7 +1875,12 @@ class Importer(object):
if field_object.many_to_many:
try:
m2ms += self._get_field_m2m(
- attribute, data, c_path, new_created, field_object
+ attribute,
+ data,
+ c_path,
+ new_created,
+ field_object,
+ idx_line=idx_line,
)
except Exception as e:
self.errors.append((self.idx_line, None, str(e)))
@@ -1836,13 +1908,13 @@ class Importer(object):
try:
c_path.append(attribute)
data[attribute], created = self.get_object(
- field_object.rel.to, data[attribute].copy(), c_path
+ field_object.rel.to, data[attribute].copy(), c_path, idx_line=idx_line
)
except ImporterError as msg:
self.errors.append((self.idx_line, None, msg))
data[attribute] = None
- def get_object(self, cls, data, path=None):
+ def get_object(self, cls, data, path=None, idx_line=None):
if not path:
path = []
m2ms = []
@@ -1880,7 +1952,15 @@ class Importer(object):
data.pop(attribute)
continue
if attribute != "__force_new":
- self.get_field(cls, attribute, data, m2ms, c_c_path, new_created)
+ self.get_field(
+ cls,
+ attribute,
+ data,
+ m2ms,
+ c_c_path,
+ new_created,
+ idx_line=idx_line,
+ )
except (ValueError, IntegrityError, FieldDoesNotExist) as e:
try:
message = str(e)
@@ -1959,7 +2039,7 @@ class Importer(object):
):
raise self._get_improperly_conf_error(cls)
if not self.simulate:
- obj = cls.objects.create(**new_dct)
+ self._create_item(cls, new_dct, idx_line)
else:
self.new_objects.append((path, cls, new_dct))
else:
@@ -2015,14 +2095,26 @@ class Importer(object):
not self.MODEL_CREATION_LIMIT
or cls in self.MODEL_CREATION_LIMIT
):
- dct["defaults"] = defaults.copy()
- obj, created = cls.objects.get_or_create(**dct)
+ q = cls.objects.filter(**dct)
+ if q.count():
+ obj = cls.objects.get(**dct)
+ else:
+ created = True
+ new_dct = dct.copy()
+ for k in defaults:
+ if k not in dct:
+ new_dct[k] = defaults[k]
+ obj = self._create_item(cls, new_dct, idx_line)
else:
try:
obj = cls.objects.get(**dct)
- dct["defaults"] = defaults.copy()
+ obj._no_post_save = (
+ True  # delayed until the end of the import
+ )
+ self._add_to_post_save(cls, obj.pk, idx_line)
except cls.DoesNotExist:
raise self._get_does_not_exist_in_db_error(cls, dct)
+ dct["defaults"] = defaults.copy()
if not created and not path and self.UNICITY_KEYS:
updated_dct = {}
@@ -2110,6 +2202,8 @@ class Importer(object):
getattr(obj, attr).add(v)
# force post save script
v = v.__class__.objects.get(pk=v.pk)
+ self._add_to_post_save(v.__class__, v.pk, idx_line)
+ v._no_post_save = True
try:
v.save()
except DatabaseError as e:
@@ -2133,13 +2227,6 @@ class Importer(object):
# defaults are not presented as matching data
dct.pop("defaults")
return self.updated_objects[-1][1], False
-
- if m2ms:
- # force post save script
- obj.save()
- if hasattr(obj, "fix"):
- # post save/m2m specific fix
- obj.fix()
except IntegrityError as e:
try:
message = str(e)
@@ -2153,6 +2240,7 @@ class Importer(object):
"Erreur d'import %s %s, contexte : %s, erreur : %s"
% (str(cls), str("__".join(path)), str(data), message)
)
+ self._add_to_post_save(obj.__class__, obj.pk, idx_line)
return obj, created
def _format_csv_line(self, values, empty="-"):