summaryrefslogtreecommitdiff
path: root/ishtar_common
diff options
context:
space:
mode:
authorÉtienne Loks <etienne.loks@iggdrasil.net>2020-06-09 12:05:37 +0200
committerÉtienne Loks <etienne.loks@iggdrasil.net>2021-02-28 12:15:20 +0100
commit3c950402b984badbbd9bad5b4616342760f4fd90 (patch)
treed59806d5799ac74b9134183fba5a0ecf8b762d65 /ishtar_common
parent006a33f7be51d968aef0795cdf4f6f4fd1e5ba2f (diff)
downloadIshtar-3c950402b984badbbd9bad5b4616342760f4fd90.tar.bz2
Ishtar-3c950402b984badbbd9bad5b4616342760f4fd90.zip
Fix container merge
Diffstat (limited to 'ishtar_common')
-rw-r--r--ishtar_common/model_merging.py28
-rw-r--r--ishtar_common/models.py5
-rw-r--r--ishtar_common/tests.py6
3 files changed, 33 insertions, 6 deletions
diff --git a/ishtar_common/model_merging.py b/ishtar_common/model_merging.py
index 06f65378d..6b839a143 100644
--- a/ishtar_common/model_merging.py
+++ b/ishtar_common/model_merging.py
@@ -20,7 +20,8 @@ def get_models():
@transaction.atomic
-def merge_model_objects(primary_object, alias_objects=None, keep_old=False):
+def merge_model_objects(primary_object, alias_objects=None, keep_old=False,
+ exclude_fields=None):
"""
Use this function to merge model objects (i.e. Users, Organizations,
etc.) and migrate all of the related fields from the alias objects to the
@@ -34,9 +35,15 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):
"""
if not alias_objects:
alias_objects = []
+ if not exclude_fields:
+ exclude_fields = []
MERGE_FIELDS = ('merge_candidate', 'merge_exclusion')
+ MERGE_STRING_FIELDS = []
+ if getattr(primary_object, "MERGE_STRING_FIELDS", None):
+ MERGE_STRING_FIELDS = primary_object.MERGE_STRING_FIELDS
+
if not isinstance(alias_objects, list):
alias_objects = [alias_objects]
@@ -80,6 +87,11 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):
alias_varname = related_object.get_accessor_name()
# The variable name on the related model.
obj_varname = related_object.field.name
+ if obj_varname in exclude_fields:
+ continue
+ if getattr(related_object.field, "related_model", None) and \
+ not related_object.related_model._meta.managed:
+ continue
try:
related_objects = getattr(alias_object, alias_varname)
except ObjectDoesNotExist:
@@ -128,7 +140,7 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):
many_to_many_objects = getattr(alias_object, alias_varname).all()
if alias_varname in blank_local_fields:
- blank_local_fields.pop(alias_varname)
+ blank_local_fields.remove(alias_varname)
for obj in many_to_many_objects.all():
getattr(alias_object, alias_varname).remove(obj)
getattr(primary_object, alias_varname).add(obj)
@@ -142,9 +154,21 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):
alias_object)
for generic_related_object in field.model.objects.filter(
**filter_kwargs):
+ if field.name in exclude_fields:
+ continue
setattr(generic_related_object, field.name, primary_object)
generic_related_object.save()
+ for field_name in MERGE_STRING_FIELDS:
+ if getattr(primary_object, field_name) and \
+ getattr(alias_object, field_name):
+ val = "{} ; {}".format(
+ getattr(primary_object, field_name),
+ getattr(alias_object, field_name))
+ if field_name in exclude_fields:
+ continue
+ setattr(primary_object, field_name, val)
+
# Try to fill all missing values in primary object by values of
# duplicates
filled_up = set()
diff --git a/ishtar_common/models.py b/ishtar_common/models.py
index 9011c4638..a6cfcf697 100644
--- a/ishtar_common/models.py
+++ b/ishtar_common/models.py
@@ -4286,8 +4286,9 @@ class Merge(models.Model):
for m in self.merge_exclusion.all():
m.delete()
- def merge(self, item, keep_old=False):
- merge_model_objects(self, item, keep_old=keep_old)
+ def merge(self, item, keep_old=False, exclude_fields=None):
+ merge_model_objects(self, item, keep_old=keep_old,
+ exclude_fields=exclude_fields)
self.generate_merge_candidate()
diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py
index 122f29d92..6d112f2d1 100644
--- a/ishtar_common/tests.py
+++ b/ishtar_common/tests.py
@@ -682,6 +682,8 @@ class GenericSerializationTest:
module = importlib.import_module(module_name + ".models")
model = getattr(module, model_name)
+ if getattr(model, "TO_BE_DELETED", False):
+ continue
current_count = model.objects.count()
result = json.loads(json_result[key])
serialization_count = len(result)
@@ -917,11 +919,11 @@ class SerializationTest(GenericSerializationTest, TestCase):
)
wl1 = WarehouseDivisionLink.objects.create(
warehouse=w1,
- division=wd1
+ container_type=ContainerType.objects.all()[0],
)
wl2 = WarehouseDivisionLink.objects.create(
warehouse=w2,
- division=wd2
+ container_type = ContainerType.objects.all()[1],
)
ContainerLocalisation.objects.create(
container=c1,