diff options
| -rw-r--r-- | archaeological_finds/models_finds.py | 6 | ||||
| -rw-r--r-- | archaeological_finds/tests.py | 14 | 
2 files changed, 20 insertions, 0 deletions
| diff --git a/archaeological_finds/models_finds.py b/archaeological_finds/models_finds.py index f4056cf4b..5b980a590 100644 --- a/archaeological_finds/models_finds.py +++ b/archaeological_finds/models_finds.py @@ -1993,6 +1993,10 @@ class Find(BulkUpdatedItem, ValueGetter, DocumentItem, BaseHistorizedItem,          if data:              for k in data:                  setattr(new, k, data[k]) +        # remove associated treatments +        if not duplicate_for_treatment and ( +                new.upstream_treatment or new.downstream_treatment): +            new.upstream_treatment, new.downstream_treatment = None, None          new.save()          # m2m fields @@ -2023,6 +2027,8 @@ class Find(BulkUpdatedItem, ValueGetter, DocumentItem, BaseHistorizedItem,                      user=user, data={"label": new.label, "external_id": None}))              # remove documents for this kind of duplicate (data entry)              new.documents.clear() +            # remove associated treatments +            new.treatments.clear()          return new      @classmethod diff --git a/archaeological_finds/tests.py b/archaeological_finds/tests.py index e39ebda32..2c78d3efa 100644 --- a/archaeological_finds/tests.py +++ b/archaeological_finds/tests.py @@ -1207,12 +1207,22 @@ class FindQATest(FindInit, TestCase):              codename='change_find'))      def test_duplicate(self): +        t1, __ = models.Treatment.objects.get_or_create( +            label="Treatment 1", +            treatment_state=models.TreatmentState.objects.all()[0] +        ) +        t2, __ = models.Treatment.objects.get_or_create( +            label="Treatment 1", +            treatment_state=models.TreatmentState.objects.all()[0] +        )          find = self.finds[0]          default_desc = "Description for duplicate"          find.description = default_desc +        find.upstream_treatment = t1          find.save()          d = Document.objects.create()          d.finds.add(find) +        find.treatments.add(t2)          c = Client()          url = reverse('find-qa-duplicate', args=[find.pk]) @@ -1245,6 +1255,10 @@ class FindQATest(FindInit, TestCase):          base_bf = find.get_first_base_find()          self.assertEqual(new_bf.context_record, base_bf.context_record)          self.assertNotIn(d, list(new.documents.all())) +        # associated treatments are removed #4850 +        self.assertIsNone(new.upstream_treatment) +        self.assertIsNone(new.downstream_treatment) +        self.assertNotIn(t2, list(new.treatments.all()))      def test_bulk_update(self):          c = Client() | 
