diff options
author | Étienne Loks <etienne.loks@iggdrasil.net> | 2020-06-09 12:05:37 +0200 |
---|---|---|
committer | Étienne Loks <etienne.loks@iggdrasil.net> | 2021-02-28 12:15:20 +0100 |
commit | cbe8d22ccecad9e09dd70f5fdbbd78bf497671a9 (patch) | |
tree | d59806d5799ac74b9134183fba5a0ecf8b762d65 | |
parent | 8d908f14a8cd15734a522c17c144efe06e21416a (diff) | |
download | Ishtar-cbe8d22ccecad9e09dd70f5fdbbd78bf497671a9.tar.bz2 Ishtar-cbe8d22ccecad9e09dd70f5fdbbd78bf497671a9.zip |
Fix container merge
-rw-r--r-- | archaeological_warehouse/forms.py | 2 | ||||
-rw-r--r-- | archaeological_warehouse/models.py | 39 | ||||
-rw-r--r-- | archaeological_warehouse/tests.py | 180 | ||||
-rw-r--r-- | ishtar_common/model_merging.py | 28 | ||||
-rw-r--r-- | ishtar_common/models.py | 5 | ||||
-rw-r--r-- | ishtar_common/tests.py | 6 |
6 files changed, 173 insertions, 87 deletions
diff --git a/archaeological_warehouse/forms.py b/archaeological_warehouse/forms.py index fb345032e..cadfaaee1 100644 --- a/archaeological_warehouse/forms.py +++ b/archaeological_warehouse/forms.py @@ -298,7 +298,7 @@ class ContainerForm(CustomForm, ManageOldType, forms.Form): kwargs.pop('limits') super(ContainerForm, self).__init__(*args, **kwargs) - def _clean_parent(self): + def clean_parent(self): if not self.cleaned_data.get("parent", None): return warehouse_id = self.cleaned_data.get("location") diff --git a/archaeological_warehouse/models.py b/archaeological_warehouse/models.py index 5cadc46da..6341db8b8 100644 --- a/archaeological_warehouse/models.py +++ b/archaeological_warehouse/models.py @@ -380,6 +380,7 @@ class WarehouseDivisionLinkManager(models.Manager): container_type__txt_idx=container_type) + class ContainerType(GeneralType): stationary = models.BooleanField( _("Stationary"), default=False, @@ -426,6 +427,7 @@ class WarehouseDivisionLink(models.Model): return self.warehouse.uuid, self.container_type.txt_idx + class ContainerTree(models.Model): CREATE_SQL = """ CREATE VIEW containers_tree AS @@ -739,6 +741,7 @@ class Container(DocumentItem, Merge, LightHistorizedItem, QRCodeItem, GeoItem, DISABLE_POLYGONS = False MERGE_ATTRIBUTE = "get_cached_division" + MERGE_STRING_FIELDS = ["old_reference"] class Meta: verbose_name = _("Container") @@ -810,19 +813,24 @@ class Container(DocumentItem, Merge, LightHistorizedItem, QRCodeItem, GeoItem, def _get_base_image_path(self): return self.location._get_base_image_path() + "/" + self.external_id - def merge(self, item, keep_old=False): - # TODO: change localisation management - locas = [ - cl.division.division.txt_idx - for cl in ContainerLocalisation.objects.filter(container=self).all() - ] - for loca in ContainerLocalisation.objects.filter(container=item).all(): - if loca.division.division.txt_idx not in locas: - loca.container = self - loca.save() + def merge(self, item, keep_old=False, exclude_fields=None): + # merge child containers + child_references = {} + for child in Container.objects.filter(parent=self).all(): + key = (child.reference, child.container_type.pk) + child_references[key] = child + for child in Container.objects.filter(parent=item).all(): + key = (child.reference, child.container_type.pk) + if key in child_references.keys(): + # parent field can cause integrity error before the end of + # the merge + child_references[key].index = None + child_references[key].merge(child, exclude_fields=["parent"]) else: - loca.delete() - super(Container, self).merge(item, keep_old=keep_old) + child.parent = self + child.save() + super(Container, self).merge(item, keep_old=keep_old, + exclude_fields=exclude_fields) @classmethod def get_query_owns(cls, ishtaruser): @@ -1176,10 +1184,10 @@ m2m_changed.connect(document_attached_changed, class ContainerLocalisationManager(models.Manager): #TODO: to be deleted.... - def get_by_natural_key(self, container, warehouse, division): + def get_by_natural_key(self, container, warehouse, container_type): return self.get(container__uuid=container, division__warehouse__uuid=warehouse, - division__division__txt_idx=division) + division__container_type__txt_idx=container_type) class ContainerLocalisation(models.Model): @@ -1190,6 +1198,7 @@ class ContainerLocalisation(models.Model): verbose_name=_("Division")) reference = models.CharField(_("Reference"), max_length=200, default='') objects = ContainerLocalisationManager() + TO_BE_DELETED = True class Meta: verbose_name = _("Container localisation") @@ -1204,7 +1213,7 @@ class ContainerLocalisation(models.Model): def natural_key(self): return self.container.uuid, self.division.warehouse.uuid,\ - self.division.division.txt_idx + self.division.container_type.txt_idx def save(self, *args, **kwargs): super(ContainerLocalisation, self).save(*args, **kwargs) diff --git a/archaeological_warehouse/tests.py b/archaeological_warehouse/tests.py index 2fc492caa..a76c2763f 100644 --- a/archaeological_warehouse/tests.py +++ b/archaeological_warehouse/tests.py @@ -119,46 +119,46 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): f2.container = c5 f2.container_ref = c5 f2.save() - wd1 = models.WarehouseDivision.objects.create( - label="Étagère", txt_idx="etagere" - ) - wd2 = models.WarehouseDivision.objects.create( - label="Allée", txt_idx="allee" - ) + ct1 = models.ContainerType.objects.all()[0] + ct2 = models.ContainerType.objects.all()[1] wl1 = models.WarehouseDivisionLink.objects.create( warehouse=w1, - division=wd1 + container_type=ct1 ) wl2 = models.WarehouseDivisionLink.objects.create( warehouse=w1, - division=wd2 + container_type=ct2 ) wl3 = models.WarehouseDivisionLink.objects.create( warehouse=w2, - division=wd2 + container_type=ct2 ) wl4 = models.WarehouseDivisionLink.objects.create( warehouse=w3, - division=wd1 + container_type=ct1 ) - models.ContainerLocalisation.objects.create( - container=c1, - division=wl1, - reference="A1" + models.Container.objects.create( + location=c1.location, + parent=c1, + container_type=ct1, + reference="A1", ) - models.ContainerLocalisation.objects.create( - container=c1, - division=wl2, + models.Container.objects.create( + location=c1.location, + parent=c1, + container_type=ct2, reference="A2" ) - models.ContainerLocalisation.objects.create( - container=c2, - division=wl3, + models.Container.objects.create( + location=c2.location, + parent=c2, + container_type=ct2, reference="A4" ) - models.ContainerLocalisation.objects.create( - container=c3, - division=wl4, + models.Container.objects.create( + location=c3.location, + parent=c3, + container_type=ct1, reference="A5" ) @@ -172,17 +172,12 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): container_json = json.loads( res[('warehouse', "archaeological_warehouse__Container")] ) - self.assertEqual(len(container_json), 5) + self.assertEqual(len(container_json), 5 + 4) div_json = json.loads( res[('warehouse', "archaeological_warehouse__WarehouseDivisionLink")] ) self.assertEqual(len(div_json), 4) - loca_json = json.loads( - res[('warehouse', - "archaeological_warehouse__ContainerLocalisation")] - ) - self.assertEqual(len(loca_json), 4) result_queryset = Operation.objects.filter(uuid=self.operations[0].uuid) res = self.generic_serialization_test( @@ -202,11 +197,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): "archaeological_warehouse__WarehouseDivisionLink")] ) self.assertEqual(len(div_json), 3) - loca_json = json.loads( - res[('warehouse', - "archaeological_warehouse__ContainerLocalisation")] - ) - self.assertEqual(len(loca_json), 3) result_queryset = ContextRecord.objects.filter( uuid=self.context_records[0].uuid) @@ -227,11 +217,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): "archaeological_warehouse__WarehouseDivisionLink")] ) self.assertEqual(len(div_json), 3) - loca_json = json.loads( - res[('warehouse', - "archaeological_warehouse__ContainerLocalisation")] - ) - self.assertEqual(len(loca_json), 3) result_queryset = Find.objects.filter(uuid=self.finds[0].uuid) res = self.generic_serialization_test( @@ -251,11 +236,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): "archaeological_warehouse__WarehouseDivisionLink")] ) self.assertEqual(len(div_json), 3) - loca_json = json.loads( - res[('warehouse', - "archaeological_warehouse__ContainerLocalisation")] - ) - self.assertEqual(len(loca_json), 3) result_queryset = models.Warehouse.objects.filter( id=self.warehouses[0].id) @@ -270,17 +250,12 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase): container_json = json.loads( res[('warehouse', "archaeological_warehouse__Container")] ) - self.assertEqual(len(container_json), 3) + self.assertEqual(len(container_json), 3 + 3) div_json = json.loads( res[('warehouse', "archaeological_warehouse__WarehouseDivisionLink")] ) self.assertEqual(len(div_json), 2) - loca_json = json.loads( - res[('warehouse', - "archaeological_warehouse__ContainerLocalisation")] - ) - self.assertEqual(len(loca_json), 2) def test_ope_serialization_with_warehouse_filter(self): res = self.generic_serialization_test( @@ -383,7 +358,7 @@ class WarehouseWizardCreationTest(WizardTest, FindInit, TestCase): }, 'divisions-warehouse_creation': [ { - 'division': None, + 'container_type': None, 'order': 42 } ] @@ -410,8 +385,8 @@ class WarehouseWizardCreationTest(WizardTest, FindInit, TestCase): models.WarehouseType.objects.all()[0].pk alt_data['warehouse-warehouse_creation']['warehouse_type'] = \ models.WarehouseType.objects.all()[0].pk - main_data['divisions-warehouse_creation'][0]['division'] = \ - models.WarehouseDivision.create_default_for_test()[0].pk + main_data['divisions-warehouse_creation'][0]['container_type'] = \ + models.ContainerType.objects.all()[0].pk self.warehouse_number = models.Warehouse.objects.count() self.warehouse_div_link = models.WarehouseDivisionLink.objects.count() super(WarehouseWizardCreationTest, self).pre_wizard() @@ -453,12 +428,13 @@ class ContainerWizardCreationTest(WizardTest, FindInit, TestCase): }, ), FormData( - 'Container creation with location', + 'Container creation', form_datas={ 'container-container_creation': { 'reference': 'hop-ref3', 'container_type': None, 'location': None, + 'parent': None, }, }, ), @@ -482,24 +458,19 @@ class ContainerWizardCreationTest(WizardTest, FindInit, TestCase): name="Alt", warehouse_type=models.WarehouseType.objects.all()[0] ) - div = models.WarehouseDivision.objects.create(label='division') - div_link = models.WarehouseDivisionLink.objects.create( - warehouse=alt_warehouse, division=div) alt_data['container-container_creation']["location"] = alt_warehouse.pk - alt_data['localisation-container_creation'] = { - 'division_{}'.format(div_link.pk): 'Combien ?' - } - + alt_data['container-container_creation']["parent"] = \ + models.Container.objects.create( + reference="Plop", + container_type=models.ContainerType.objects.all()[1], + location=alt_warehouse + ).pk self.container_number = models.Container.objects.count() - self.localisation_detail_number = \ - models.ContainerLocalisation.objects.count() super(ContainerWizardCreationTest, self).pre_wizard() def post_wizard(self): self.assertEqual(models.Container.objects.count(), self.container_number + 3) - self.assertEqual(models.ContainerLocalisation.objects.count(), - self.localisation_detail_number + 1) class WarehouseTest(TestCase): @@ -773,3 +744,82 @@ class ContainerTest(FindInit, TestCase): find_lst = [f for f in container_1.finds.all()] for f in [find0, find1, find2]: self.assertIn(f, find_lst) + + def test_merge_included_containers(self): + ct = models.ContainerType.objects.all()[0] + ct2 = models.ContainerType.objects.all()[1] + ct3 = models.ContainerType.objects.all()[2] + self.create_finds() + self.create_finds() + self.create_finds() + self.create_finds() + find0 = self.finds[0] + find1 = self.finds[1] + find2 = self.finds[2] + find3 = self.finds[3] + + top_container_1 = models.Container.objects.create( + reference="Topref 1", + location=self.main_warehouse, + container_type=ct) + find0.container = top_container_1 + find0.container_ref = top_container_1 + find0.save() + + top_container_2 = models.Container.objects.create( + reference="Topref 2", + location=self.main_warehouse, + container_type=ct) + find1.container = top_container_2 + find1.container_ref = top_container_2 + find1.save() + + middle_container_1 = models.Container.objects.create( + reference="Middle ref", + location=self.main_warehouse, + parent=top_container_1, + container_type=ct2) + find2.container = middle_container_1 + find2.container_ref = middle_container_1 + find2.save() + + middle_container_2 = models.Container.objects.create( + reference="Middle ref", + location=self.main_warehouse, + parent=top_container_2, + container_type=ct2) + + bottom_container_3 = models.Container.objects.create( + reference="Bottom ref", + location=self.main_warehouse, + parent=middle_container_2, + container_type=ct3) + find3.container = bottom_container_3 + find3.container_ref = bottom_container_3 + find3.save() + + top_container_1.merge(top_container_2) + + find0 = Find.objects.get(pk=find0.pk) + self.assertEqual(find0.container, top_container_1) + find1 = Find.objects.get(pk=find1.pk) + self.assertEqual(find1.container, top_container_1) + q = models.Container.objects.filter(reference="Topref 2") + self.assertEqual(q.count(), 0) + q = models.Container.objects.filter(reference="Topref 1") + self.assertEqual(q.count(), 1) + top_ref = q.all()[0] + self.assertEqual(top_ref.finds.count(), 2) + + q = models.Container.objects.filter(reference="Middle ref") + self.assertEqual(q.count(), 1) + middle = q.all()[0] + self.assertEqual(middle.parent, top_ref) + + q = models.Container.objects.filter(reference="Bottom ref") + self.assertEqual(q.count(), 1) + bottom = q.all()[0] + self.assertEqual(bottom.parent, middle) + + + diff --git a/ishtar_common/model_merging.py b/ishtar_common/model_merging.py index 06f65378d..6b839a143 100644 --- a/ishtar_common/model_merging.py +++ b/ishtar_common/model_merging.py @@ -20,7 +20,8 @@ def get_models(): @transaction.atomic -def merge_model_objects(primary_object, alias_objects=None, keep_old=False): +def merge_model_objects(primary_object, alias_objects=None, keep_old=False, + exclude_fields=None): """ Use this function to merge model objects (i.e. Users, Organizations, etc.) and migrate all of the related fields from the alias objects to the @@ -34,9 +35,15 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False): """ if not alias_objects: alias_objects = [] + if not exclude_fields: + exclude_fields = [] MERGE_FIELDS = ('merge_candidate', 'merge_exclusion') + MERGE_STRING_FIELDS = [] + if getattr(primary_object, "MERGE_STRING_FIELDS", None): + MERGE_STRING_FIELDS = primary_object.MERGE_STRING_FIELDS + if not isinstance(alias_objects, list): alias_objects = [alias_objects] @@ -80,6 +87,11 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False): alias_varname = related_object.get_accessor_name() # The variable name on the related model. obj_varname = related_object.field.name + if obj_varname in exclude_fields: + continue + if getattr(related_object.field, "related_model", None) and \ + not related_object.related_model._meta.managed: + continue try: related_objects = getattr(alias_object, alias_varname) except ObjectDoesNotExist: @@ -128,7 +140,7 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False): many_to_many_objects = getattr(alias_object, alias_varname).all() if alias_varname in blank_local_fields: - blank_local_fields.pop(alias_varname) + blank_local_fields.remove(alias_varname) for obj in many_to_many_objects.all(): getattr(alias_object, alias_varname).remove(obj) getattr(primary_object, alias_varname).add(obj) @@ -142,9 +154,21 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False): alias_object) for generic_related_object in field.model.objects.filter( **filter_kwargs): + if field.name in exclude_fields: + continue setattr(generic_related_object, field.name, primary_object) generic_related_object.save() + for field_name in MERGE_STRING_FIELDS: + if getattr(primary_object, field_name) and \ + getattr(alias_object, field_name): + val = "{} ; {}".format( + getattr(primary_object, field_name), + getattr(alias_object, field_name)) + if field_name in exclude_fields: + continue + setattr(primary_object, field_name, val) + # Try to fill all missing values in primary object by values of # duplicates filled_up = set() diff --git a/ishtar_common/models.py b/ishtar_common/models.py index 9011c4638..a6cfcf697 100644 --- a/ishtar_common/models.py +++ b/ishtar_common/models.py @@ -4286,8 +4286,9 @@ class Merge(models.Model): for m in self.merge_exclusion.all(): m.delete() - def merge(self, item, keep_old=False): - merge_model_objects(self, item, keep_old=keep_old) + def merge(self, item, keep_old=False, exclude_fields=None): + merge_model_objects(self, item, keep_old=keep_old, + exclude_fields=exclude_fields) self.generate_merge_candidate() diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index 122f29d92..6d112f2d1 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -682,6 +682,8 @@ class GenericSerializationTest: module = importlib.import_module(module_name + ".models") model = getattr(module, model_name) + if getattr(model, "TO_BE_DELETED", False): + continue current_count = model.objects.count() result = json.loads(json_result[key]) serialization_count = len(result) @@ -917,11 +919,11 @@ class SerializationTest(GenericSerializationTest, TestCase): ) wl1 = WarehouseDivisionLink.objects.create( warehouse=w1, - division=wd1 + container_type=ContainerType.objects.all()[0], ) wl2 = WarehouseDivisionLink.objects.create( warehouse=w2, - division=wd2 + container_type = ContainerType.objects.all()[1], ) ContainerLocalisation.objects.create( container=c1, |