diff options
-rw-r--r-- | archaeological_finds/models_finds.py | 6 | ||||
-rw-r--r-- | archaeological_finds/tests.py | 14 |
2 files changed, 20 insertions, 0 deletions
diff --git a/archaeological_finds/models_finds.py b/archaeological_finds/models_finds.py index f4056cf4b..5b980a590 100644 --- a/archaeological_finds/models_finds.py +++ b/archaeological_finds/models_finds.py @@ -1993,6 +1993,10 @@ class Find(BulkUpdatedItem, ValueGetter, DocumentItem, BaseHistorizedItem, if data: for k in data: setattr(new, k, data[k]) + # remove associated treatments + if not duplicate_for_treatment and ( + new.upstream_treatment or new.downstream_treatment): + new.upstream_treatment, new.downstream_treatment = None, None new.save() # m2m fields @@ -2023,6 +2027,8 @@ class Find(BulkUpdatedItem, ValueGetter, DocumentItem, BaseHistorizedItem, user=user, data={"label": new.label, "external_id": None})) # remove documents for this kind of duplicate (data entry) new.documents.clear() + # remove associated treatments + new.treatments.clear() return new @classmethod diff --git a/archaeological_finds/tests.py b/archaeological_finds/tests.py index e39ebda32..2c78d3efa 100644 --- a/archaeological_finds/tests.py +++ b/archaeological_finds/tests.py @@ -1207,12 +1207,22 @@ class FindQATest(FindInit, TestCase): codename='change_find')) def test_duplicate(self): + t1, __ = models.Treatment.objects.get_or_create( + label="Treatment 1", + treatment_state=models.TreatmentState.objects.all()[0] + ) + t2, __ = models.Treatment.objects.get_or_create( + label="Treatment 1", + treatment_state=models.TreatmentState.objects.all()[0] + ) find = self.finds[0] default_desc = "Description for duplicate" find.description = default_desc + find.upstream_treatment = t1 find.save() d = Document.objects.create() d.finds.add(find) + find.treatments.add(t2) c = Client() url = reverse('find-qa-duplicate', args=[find.pk]) @@ -1245,6 +1255,10 @@ class FindQATest(FindInit, TestCase): base_bf = find.get_first_base_find() self.assertEqual(new_bf.context_record, base_bf.context_record) self.assertNotIn(d, list(new.documents.all())) + # associated treatments are removed #4850 + self.assertIsNone(new.upstream_treatment) + self.assertIsNone(new.downstream_treatment) + self.assertNotIn(t2, list(new.treatments.all())) def test_bulk_update(self): c = Client() |