diff options
-rw-r--r-- | ishtar_common/models_common.py | 73 | ||||
-rw-r--r-- | ishtar_common/tests.py | 77 |
2 files changed, 135 insertions, 15 deletions
diff --git a/ishtar_common/models_common.py b/ishtar_common/models_common.py index d32f835bd..adf1c1516 100644 --- a/ishtar_common/models_common.py +++ b/ishtar_common/models_common.py @@ -2581,18 +2581,13 @@ class GeoVectorData(Imported, OwnPerms): post_save.connect(post_save_geodata, sender=GeoVectorData) -def geodata_attached_changed(sender, **kwargs): - # manage main geoitem and cascade association - instance = kwargs.get("instance", None) - model = kwargs.get("model", None) - pk_set = kwargs.get("pk_set", None) - if not instance or not model or not pk_set: - return +def geodata_attached_post_add(model, instance, pk_set): item_pks = list(model.objects.filter(pk__in=pk_set).values_list("pk", flat=True)) if not item_pks: return - if not hasattr(instance, "_geodata"): # use a cache to manage during geodata attach + # use a cache to manage during geodata attach + if not hasattr(instance, "_geodata"): instance._geodata = [] if not instance.main_geodata_id: instance.main_geodata_id = item_pks[0] @@ -2617,11 +2612,71 @@ def geodata_attached_changed(sender, **kwargs): if not q.count(): if not child: child = model.objects.get(pk=pk) - if not pk in geoitems: + if pk not in geoitems: geoitems[pk] = GeoVectorData.objects.get(pk=pk) child_model.objects.get(pk=child_id).geodata.add(geoitems[pk]) +def geodata_attached_remove(model, instance, pk_set=None, clear=False): + if clear: + item_pks = getattr(instance, "_geodata_clear_item_pks", []) + else: + item_pks = list(model.objects.filter(pk__in=pk_set).values_list("pk", flat=True)) + if not item_pks: + return + + # use a cache to manage during geodata attach + if not hasattr(instance, "_geodata"): + instance._geodata = [] + if instance.main_geodata_id in item_pks: + instance.main_geodata_id = None + instance.skip_history_when_saving = True + instance._no_move = True + if not hasattr(instance, "_geodata"): + instance._geodata = [] + instance._geodata += [pk for pk in item_pks if pk not in instance._geodata] + instance.save() + + # for all sub item verify that the geo items are present + for query in instance.geodata_child_item_queries(): + child_model = query.model + m2m_model = child_model.geodata.through + m2m_key = f"{child_model._meta.model_name}_id" + geoitems = {} + for child_id in query.values_list("id", flat=True): + child = None + for pk in item_pks: + q = m2m_model.objects.filter(**{m2m_key: child_id, + "geovectordata_id": pk}) + if q.count(): + if not child: + child = model.objects.get(pk=pk) + if pk not in geoitems: + geoitems[pk] = GeoVectorData.objects.get(pk=pk) + child_model.objects.get(pk=child_id).geodata.remove(geoitems[pk]) + + +def geodata_attached_changed(sender, **kwargs): + # manage main geoitem and cascade association + instance = kwargs.get("instance", None) + model = kwargs.get("model", None) + pk_set = kwargs.get("pk_set", None) + action = kwargs.get("action", None) + if not instance or not model: + return + + if action == "post_add": + geodata_attached_post_add(model, instance, pk_set) + elif action == "post_remove": + geodata_attached_remove(model, instance, pk_set) + elif action == "pre_clear": + instance._geodata_clear_item_pks = list( + instance.geodata.values_list("id", flat=True) + ) + elif action == "post_clear": + geodata_attached_remove(model, instance, clear=True) + + class GeographicItem(models.Model): main_geodata = models.ForeignKey( GeoVectorData, diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index 74e746c61..4156f9f0f 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -3372,6 +3372,16 @@ class GeoVectorTest(TestCase): app_label=self.app_source, model=self.model_source ).pk + self.alt_data_type = models.GeoDataType.objects.get( + txt_idx="basefind-center", + ) + self.alt_app_source = "archaeological_finds" + self.alt_model_source = "basefind" + self.alt_source_pk = self.base_find.pk + self.alt_source_content_type_pk = ContentType.objects.get( + app_label=self.alt_app_source, + model=self.alt_model_source + ).pk def _reinit_objects(self): # get object from db @@ -3383,13 +3393,16 @@ class GeoVectorTest(TestCase): BaseFind = apps.get_model("archaeological_finds", "BaseFind") self.base_find = BaseFind.objects.get(pk=self.base_find.pk) - def _create_geodata(self): + def _create_geodata(self, alt=False): + ct = self.source_content_type_pk if not alt else self.alt_source_content_type_pk + source = self.source_pk if not alt else self.alt_source_pk + dt = self.data_type if not alt else self.alt_data_type return models.GeoVectorData.objects.create( - source_content_type_id=self.source_content_type_pk, - source_id=self.source_pk, + source_content_type_id=ct, + source_id=source, name="Test geo", origin=self.origin, - data_type=self.data_type, + data_type=dt, provider=self.provider, comment="This is a comment." ) @@ -3402,6 +3415,9 @@ class GeoVectorTest(TestCase): self.assertIsNone(self.base_find.main_geodata) self.assertEqual(self.base_find.geodata.count(), 0) + geo_vector_find = self._create_geodata(alt=True) + self.base_find.geodata.add(geo_vector_find) + geo_vector = self._create_geodata() self.operation.geodata.add(geo_vector) @@ -3410,10 +3426,59 @@ class GeoVectorTest(TestCase): self.assertEqual(self.operation.main_geodata, geo_vector) self.assertEqual(self.context_record.geodata.count(), 1) self.assertEqual(self.context_record.main_geodata, geo_vector) + self.assertEqual(self.base_find.geodata.count(), 2) + self.assertEqual(self.base_find.main_geodata, geo_vector_find) + + geo_vector2 = self._create_geodata() + self.operation.geodata.add(geo_vector2) + + self._reinit_objects() + self.assertEqual(self.operation.geodata.count(), 2) + self.assertEqual(self.operation.main_geodata, geo_vector) # no change + self.assertEqual(self.context_record.geodata.count(), 2) + self.assertEqual(self.context_record.main_geodata, geo_vector) # no change + self.assertEqual(self.base_find.geodata.count(), 3) + self.assertEqual(self.base_find.main_geodata, geo_vector_find) # no change + + def test_cascade_remove(self): + geo_vector = self._create_geodata() + self.operation.geodata.add(geo_vector) + geo_vector2 = self._create_geodata() + self.operation.geodata.add(geo_vector2) + geo_vector_find = self._create_geodata(alt=True) + self.base_find.geodata.add(geo_vector_find) + + self.operation.geodata.remove(geo_vector) + self._reinit_objects() + # main geoitem changed to geovector2 + self.assertEqual(self.operation.main_geodata, geo_vector2) + self.assertEqual(self.operation.geodata.count(), 1) + self.assertEqual(self.context_record.main_geodata, geo_vector2) + self.assertEqual(self.context_record.geodata.count(), 1) + self.assertEqual(self.base_find.main_geodata, geo_vector2) + self.assertEqual(self.base_find.geodata.count(), 2) + + self.operation.geodata.remove(geo_vector2) + self._reinit_objects() + self.assertIsNone(self.operation.main_geodata) + self.assertEqual(self.operation.geodata.count(), 0) + self.assertIsNone(self.context_record.main_geodata) + self.assertEqual(self.context_record.geodata.count(), 0) + self.assertEqual(self.base_find.main_geodata, geo_vector_find) + self.assertEqual(self.base_find.geodata.count(), 1) + + self.operation.geodata.add(geo_vector) + self.operation.geodata.add(geo_vector2) + + self.operation.geodata.clear() + self._reinit_objects() + self.assertIsNone(self.operation.main_geodata) + self.assertEqual(self.operation.geodata.count(), 0) + self.assertIsNone(self.context_record.main_geodata) + self.assertEqual(self.context_record.geodata.count(), 0) + self.assertEqual(self.base_find.main_geodata, geo_vector_find) self.assertEqual(self.base_find.geodata.count(), 1) - self.assertEqual(self.base_find.main_geodata, geo_vector) - # test geo item remove # test town add # test town remove |