diff options
| -rw-r--r-- | archaeological_warehouse/forms.py | 2 | ||||
| -rw-r--r-- | archaeological_warehouse/models.py | 39 | ||||
| -rw-r--r-- | archaeological_warehouse/tests.py | 180 | ||||
| -rw-r--r-- | ishtar_common/model_merging.py | 28 | ||||
| -rw-r--r-- | ishtar_common/models.py | 5 | ||||
| -rw-r--r-- | ishtar_common/tests.py | 6 | 
6 files changed, 173 insertions, 87 deletions
| diff --git a/archaeological_warehouse/forms.py b/archaeological_warehouse/forms.py index fb345032e..cadfaaee1 100644 --- a/archaeological_warehouse/forms.py +++ b/archaeological_warehouse/forms.py @@ -298,7 +298,7 @@ class ContainerForm(CustomForm, ManageOldType, forms.Form):              kwargs.pop('limits')          super(ContainerForm, self).__init__(*args, **kwargs) -    def _clean_parent(self): +    def clean_parent(self):          if not self.cleaned_data.get("parent", None):              return          warehouse_id = self.cleaned_data.get("location") diff --git a/archaeological_warehouse/models.py b/archaeological_warehouse/models.py index 5cadc46da..6341db8b8 100644 --- a/archaeological_warehouse/models.py +++ b/archaeological_warehouse/models.py @@ -380,6 +380,7 @@ class WarehouseDivisionLinkManager(models.Manager):                          container_type__txt_idx=container_type) +  class ContainerType(GeneralType):      stationary = models.BooleanField(          _("Stationary"), default=False, @@ -426,6 +427,7 @@ class WarehouseDivisionLink(models.Model):          return self.warehouse.uuid, self.container_type.txt_idx +  class ContainerTree(models.Model):      CREATE_SQL = """      CREATE VIEW containers_tree AS @@ -739,6 +741,7 @@ class Container(DocumentItem, Merge, LightHistorizedItem, QRCodeItem, GeoItem,      DISABLE_POLYGONS = False      MERGE_ATTRIBUTE = "get_cached_division" +    MERGE_STRING_FIELDS = ["old_reference"]      class Meta:          verbose_name = _("Container") @@ -810,19 +813,24 @@ class Container(DocumentItem, Merge, LightHistorizedItem, QRCodeItem, GeoItem,      def _get_base_image_path(self):          return self.location._get_base_image_path() + "/" + self.external_id -    def merge(self, item, keep_old=False): -        # TODO: change localisation management -        locas = [ -            cl.division.division.txt_idx -            for cl in ContainerLocalisation.objects.filter(container=self).all() -        ] -        for loca in ContainerLocalisation.objects.filter(container=item).all(): -            if loca.division.division.txt_idx not in locas: -                loca.container = self -                loca.save() +    def merge(self, item, keep_old=False, exclude_fields=None): +        # merge child containers +        child_references = {} +        for child in Container.objects.filter(parent=self).all(): +            key = (child.reference, child.container_type.pk) +            child_references[key] = child +        for child in Container.objects.filter(parent=item).all(): +            key = (child.reference, child.container_type.pk) +            if key in child_references.keys(): +                # parent field can cause integrity error before the end of +                # the merge +                child_references[key].index = None +                child_references[key].merge(child, exclude_fields=["parent"])              else: -                loca.delete() -        super(Container, self).merge(item, keep_old=keep_old) +                child.parent = self +                child.save() +        super(Container, self).merge(item, keep_old=keep_old, +                                     exclude_fields=exclude_fields)      @classmethod      def get_query_owns(cls, ishtaruser): @@ -1176,10 +1184,10 @@ m2m_changed.connect(document_attached_changed,  class ContainerLocalisationManager(models.Manager):      #TODO: to be deleted.... -    def get_by_natural_key(self, container, warehouse, division): +    def get_by_natural_key(self, container, warehouse, container_type):          return self.get(container__uuid=container,                          division__warehouse__uuid=warehouse, -                        division__division__txt_idx=division) +                        division__container_type__txt_idx=container_type)  class ContainerLocalisation(models.Model): @@ -1190,6 +1198,7 @@ class ContainerLocalisation(models.Model):                                   verbose_name=_("Division"))      reference = models.CharField(_("Reference"), max_length=200, default='')      objects = ContainerLocalisationManager() +    TO_BE_DELETED = True      class Meta:          verbose_name = _("Container localisation") @@ -1204,7 +1213,7 @@ class ContainerLocalisation(models.Model):      def natural_key(self):          return self.container.uuid, self.division.warehouse.uuid,\ -               self.division.division.txt_idx +               self.division.container_type.txt_idx      def save(self, *args, **kwargs):          super(ContainerLocalisation, self).save(*args, **kwargs) diff --git a/archaeological_warehouse/tests.py b/archaeological_warehouse/tests.py index 2fc492caa..a76c2763f 100644 --- a/archaeological_warehouse/tests.py +++ b/archaeological_warehouse/tests.py @@ -119,46 +119,46 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):          f2.container = c5          f2.container_ref = c5          f2.save() -        wd1 = models.WarehouseDivision.objects.create( -            label="Étagère", txt_idx="etagere" -        ) -        wd2 = models.WarehouseDivision.objects.create( -            label="Allée", txt_idx="allee" -        ) +        ct1 = models.ContainerType.objects.all()[0] +        ct2 = models.ContainerType.objects.all()[1]          wl1 = models.WarehouseDivisionLink.objects.create(              warehouse=w1, -            division=wd1 +            container_type=ct1          )          wl2 = models.WarehouseDivisionLink.objects.create(              warehouse=w1, -            division=wd2 +            container_type=ct2          )          wl3 = models.WarehouseDivisionLink.objects.create(              warehouse=w2, -            division=wd2 +            container_type=ct2          )          wl4 = models.WarehouseDivisionLink.objects.create(              warehouse=w3, -            division=wd1 +            container_type=ct1          ) -        models.ContainerLocalisation.objects.create( -            container=c1, -            division=wl1, -            reference="A1" +        models.Container.objects.create( +            location=c1.location, +            parent=c1, +            container_type=ct1, +            reference="A1",          ) -        models.ContainerLocalisation.objects.create( -            container=c1, -            division=wl2, +        models.Container.objects.create( +            location=c1.location, +            parent=c1, +            container_type=ct2,              reference="A2"          ) -        models.ContainerLocalisation.objects.create( -            container=c2, -            division=wl3, +        models.Container.objects.create( +            location=c2.location, +            parent=c2, +            container_type=ct2,              reference="A4"          ) -        models.ContainerLocalisation.objects.create( -            container=c3, -            division=wl4, +        models.Container.objects.create( +            location=c3.location, +            parent=c3, +            container_type=ct1,              reference="A5"          ) @@ -172,17 +172,12 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):          container_json = json.loads(              res[('warehouse', "archaeological_warehouse__Container")]          ) -        self.assertEqual(len(container_json), 5) +        self.assertEqual(len(container_json), 5 + 4)          div_json = json.loads(              res[('warehouse',                   "archaeological_warehouse__WarehouseDivisionLink")]          )          self.assertEqual(len(div_json), 4) -        loca_json = json.loads( -            res[('warehouse', -                 "archaeological_warehouse__ContainerLocalisation")] -        ) -        self.assertEqual(len(loca_json), 4)          result_queryset = Operation.objects.filter(uuid=self.operations[0].uuid)          res = self.generic_serialization_test( @@ -202,11 +197,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):                   "archaeological_warehouse__WarehouseDivisionLink")]          )          self.assertEqual(len(div_json), 3) -        loca_json = json.loads( -            res[('warehouse', -                 "archaeological_warehouse__ContainerLocalisation")] -        ) -        self.assertEqual(len(loca_json), 3)          result_queryset = ContextRecord.objects.filter(              uuid=self.context_records[0].uuid) @@ -227,11 +217,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):                   "archaeological_warehouse__WarehouseDivisionLink")]          )          self.assertEqual(len(div_json), 3) -        loca_json = json.loads( -            res[('warehouse', -                 "archaeological_warehouse__ContainerLocalisation")] -        ) -        self.assertEqual(len(loca_json), 3)          result_queryset = Find.objects.filter(uuid=self.finds[0].uuid)          res = self.generic_serialization_test( @@ -251,11 +236,6 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):                   "archaeological_warehouse__WarehouseDivisionLink")]          )          self.assertEqual(len(div_json), 3) -        loca_json = json.loads( -            res[('warehouse', -                 "archaeological_warehouse__ContainerLocalisation")] -        ) -        self.assertEqual(len(loca_json), 3)          result_queryset = models.Warehouse.objects.filter(              id=self.warehouses[0].id) @@ -270,17 +250,12 @@ class SerializationTest(GenericSerializationTest, FindInit, TestCase):          container_json = json.loads(              res[('warehouse', "archaeological_warehouse__Container")]          ) -        self.assertEqual(len(container_json), 3) +        self.assertEqual(len(container_json), 3 + 3)          div_json = json.loads(              res[('warehouse',                   "archaeological_warehouse__WarehouseDivisionLink")]          )          self.assertEqual(len(div_json), 2) -        loca_json = json.loads( -            res[('warehouse', -                 "archaeological_warehouse__ContainerLocalisation")] -        ) -        self.assertEqual(len(loca_json), 2)      def test_ope_serialization_with_warehouse_filter(self):          res = self.generic_serialization_test( @@ -383,7 +358,7 @@ class WarehouseWizardCreationTest(WizardTest, FindInit, TestCase):                  },                  'divisions-warehouse_creation': [                      { -                        'division': None, +                        'container_type': None,                          'order': 42                      }                  ] @@ -410,8 +385,8 @@ class WarehouseWizardCreationTest(WizardTest, FindInit, TestCase):              models.WarehouseType.objects.all()[0].pk          alt_data['warehouse-warehouse_creation']['warehouse_type'] = \              models.WarehouseType.objects.all()[0].pk -        main_data['divisions-warehouse_creation'][0]['division'] = \ -            models.WarehouseDivision.create_default_for_test()[0].pk +        main_data['divisions-warehouse_creation'][0]['container_type'] = \ +            models.ContainerType.objects.all()[0].pk          self.warehouse_number = models.Warehouse.objects.count()          self.warehouse_div_link = models.WarehouseDivisionLink.objects.count()          super(WarehouseWizardCreationTest, self).pre_wizard() @@ -453,12 +428,13 @@ class ContainerWizardCreationTest(WizardTest, FindInit, TestCase):              },          ),          FormData( -            'Container creation with location', +            'Container creation',              form_datas={                  'container-container_creation': {                      'reference': 'hop-ref3',                      'container_type': None,                      'location': None, +                    'parent': None,                  },              },          ), @@ -482,24 +458,19 @@ class ContainerWizardCreationTest(WizardTest, FindInit, TestCase):              name="Alt",              warehouse_type=models.WarehouseType.objects.all()[0]          ) -        div = models.WarehouseDivision.objects.create(label='division') -        div_link = models.WarehouseDivisionLink.objects.create( -            warehouse=alt_warehouse, division=div)          alt_data['container-container_creation']["location"] = alt_warehouse.pk -        alt_data['localisation-container_creation'] = { -            'division_{}'.format(div_link.pk): 'Combien ?' -        } - +        alt_data['container-container_creation']["parent"] = \ +            models.Container.objects.create( +                reference="Plop", +                container_type=models.ContainerType.objects.all()[1], +                location=alt_warehouse +            ).pk          self.container_number = models.Container.objects.count() -        self.localisation_detail_number = \ -            models.ContainerLocalisation.objects.count()          super(ContainerWizardCreationTest, self).pre_wizard()      def post_wizard(self):          self.assertEqual(models.Container.objects.count(),                           self.container_number + 3) -        self.assertEqual(models.ContainerLocalisation.objects.count(), -                         self.localisation_detail_number + 1)  class WarehouseTest(TestCase): @@ -773,3 +744,82 @@ class ContainerTest(FindInit, TestCase):          find_lst = [f for f in container_1.finds.all()]          for f in [find0, find1, find2]:              self.assertIn(f, find_lst) + +    def test_merge_included_containers(self): +        ct = models.ContainerType.objects.all()[0] +        ct2 = models.ContainerType.objects.all()[1] +        ct3 = models.ContainerType.objects.all()[2] +        self.create_finds() +        self.create_finds() +        self.create_finds() +        self.create_finds() +        find0 = self.finds[0] +        find1 = self.finds[1] +        find2 = self.finds[2] +        find3 = self.finds[3] + +        top_container_1 = models.Container.objects.create( +            reference="Topref 1", +            location=self.main_warehouse, +            container_type=ct) +        find0.container = top_container_1 +        find0.container_ref = top_container_1 +        find0.save() + +        top_container_2 = models.Container.objects.create( +            reference="Topref 2", +            location=self.main_warehouse, +            container_type=ct) +        find1.container = top_container_2 +        find1.container_ref = top_container_2 +        find1.save() + +        middle_container_1 = models.Container.objects.create( +            reference="Middle ref", +            location=self.main_warehouse, +            parent=top_container_1, +            container_type=ct2) +        find2.container = middle_container_1 +        find2.container_ref = middle_container_1 +        find2.save() + +        middle_container_2 = models.Container.objects.create( +            reference="Middle ref", +            location=self.main_warehouse, +            parent=top_container_2, +            container_type=ct2) + +        bottom_container_3 = models.Container.objects.create( +            reference="Bottom ref", +            location=self.main_warehouse, +            parent=middle_container_2, +            container_type=ct3) +        find3.container = bottom_container_3 +        find3.container_ref = bottom_container_3 +        find3.save() + +        top_container_1.merge(top_container_2) + +        find0 = Find.objects.get(pk=find0.pk) +        self.assertEqual(find0.container, top_container_1) +        find1 = Find.objects.get(pk=find1.pk) +        self.assertEqual(find1.container, top_container_1) +        q = models.Container.objects.filter(reference="Topref 2") +        self.assertEqual(q.count(), 0) +        q = models.Container.objects.filter(reference="Topref 1") +        self.assertEqual(q.count(), 1) +        top_ref = q.all()[0] +        self.assertEqual(top_ref.finds.count(), 2) + +        q = models.Container.objects.filter(reference="Middle ref") +        self.assertEqual(q.count(), 1) +        middle = q.all()[0] +        self.assertEqual(middle.parent, top_ref) + +        q = models.Container.objects.filter(reference="Bottom ref") +        self.assertEqual(q.count(), 1) +        bottom = q.all()[0] +        self.assertEqual(bottom.parent, middle) + + + diff --git a/ishtar_common/model_merging.py b/ishtar_common/model_merging.py index 06f65378d..6b839a143 100644 --- a/ishtar_common/model_merging.py +++ b/ishtar_common/model_merging.py @@ -20,7 +20,8 @@ def get_models():  @transaction.atomic -def merge_model_objects(primary_object, alias_objects=None, keep_old=False): +def merge_model_objects(primary_object, alias_objects=None, keep_old=False, +                        exclude_fields=None):      """      Use this function to merge model objects (i.e. Users, Organizations,      etc.) and migrate all of the related fields from the alias objects to the @@ -34,9 +35,15 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):      """      if not alias_objects:          alias_objects = [] +    if not exclude_fields: +        exclude_fields = []      MERGE_FIELDS = ('merge_candidate', 'merge_exclusion') +    MERGE_STRING_FIELDS = [] +    if getattr(primary_object, "MERGE_STRING_FIELDS", None): +        MERGE_STRING_FIELDS = primary_object.MERGE_STRING_FIELDS +      if not isinstance(alias_objects, list):          alias_objects = [alias_objects] @@ -80,6 +87,11 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):              alias_varname = related_object.get_accessor_name()              # The variable name on the related model.              obj_varname = related_object.field.name +            if obj_varname in exclude_fields: +                continue +            if getattr(related_object.field, "related_model", None) and \ +                    not related_object.related_model._meta.managed: +                continue              try:                  related_objects = getattr(alias_object, alias_varname)              except ObjectDoesNotExist: @@ -128,7 +140,7 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):              many_to_many_objects = getattr(alias_object, alias_varname).all()              if alias_varname in blank_local_fields: -                blank_local_fields.pop(alias_varname) +                blank_local_fields.remove(alias_varname)              for obj in many_to_many_objects.all():                  getattr(alias_object, alias_varname).remove(obj)                  getattr(primary_object, alias_varname).add(obj) @@ -142,9 +154,21 @@ def merge_model_objects(primary_object, alias_objects=None, keep_old=False):                  alias_object)              for generic_related_object in field.model.objects.filter(                      **filter_kwargs): +                if field.name in exclude_fields: +                    continue                  setattr(generic_related_object, field.name, primary_object)                  generic_related_object.save() +        for field_name in MERGE_STRING_FIELDS: +            if getattr(primary_object, field_name) and \ +                    getattr(alias_object, field_name): +                val = "{} ; {}".format( +                    getattr(primary_object, field_name), +                    getattr(alias_object, field_name)) +                if field_name in exclude_fields: +                    continue +                setattr(primary_object, field_name, val) +          # Try to fill all missing values in primary object by values of          # duplicates          filled_up = set() diff --git a/ishtar_common/models.py b/ishtar_common/models.py index 9011c4638..a6cfcf697 100644 --- a/ishtar_common/models.py +++ b/ishtar_common/models.py @@ -4286,8 +4286,9 @@ class Merge(models.Model):          for m in self.merge_exclusion.all():              m.delete() -    def merge(self, item, keep_old=False): -        merge_model_objects(self, item, keep_old=keep_old) +    def merge(self, item, keep_old=False, exclude_fields=None): +        merge_model_objects(self, item, keep_old=keep_old, +                            exclude_fields=exclude_fields)          self.generate_merge_candidate() diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index 122f29d92..6d112f2d1 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -682,6 +682,8 @@ class GenericSerializationTest:                  module = importlib.import_module(module_name + ".models")              model = getattr(module, model_name) +            if getattr(model, "TO_BE_DELETED", False): +                continue              current_count = model.objects.count()              result = json.loads(json_result[key])              serialization_count = len(result) @@ -917,11 +919,11 @@ class SerializationTest(GenericSerializationTest, TestCase):          )          wl1 = WarehouseDivisionLink.objects.create(              warehouse=w1, -            division=wd1 +            container_type=ContainerType.objects.all()[0],          )          wl2 = WarehouseDivisionLink.objects.create(              warehouse=w2, -            division=wd2 +            container_type = ContainerType.objects.all()[1],          )          ContainerLocalisation.objects.create(              container=c1, | 
