diff options
| -rw-r--r-- | archaeological_context_records/models.py | 42 | ||||
| -rw-r--r-- | archaeological_context_records/tests.py | 97 | 
2 files changed, 76 insertions, 63 deletions
diff --git a/archaeological_context_records/models.py b/archaeological_context_records/models.py index 65145be98..454e83e8a 100644 --- a/archaeological_context_records/models.py +++ b/archaeological_context_records/models.py @@ -1454,14 +1454,23 @@ class ContextRecordTree(RelationsViews):          return new_trees      @classmethod -    def _get_equals(cls, item_id, equal_rel_types): -        equals = list(RecordRelations.objects.values_list( +    def _get_equals(cls, item_id, equal_rel_types, exclude=None): +        if not exclude: +            exclude = [item_id] +        q = RecordRelations.objects.values_list(              "right_record_id", flat=True).filter( -            left_record_id=item_id, relation_type_id__in=equal_rel_types)) -        equals += list(RecordRelations.objects.values_list( +            left_record_id=item_id, relation_type_id__in=equal_rel_types) +        q = q.exclude(right_record_id__in=exclude) +        equals = list(q) +        q = RecordRelations.objects.values_list(              "left_record_id", flat=True).filter( -            right_record_id=item_id, relation_type_id__in=equal_rel_types)) -        return set(equals) +            right_record_id=item_id, relation_type_id__in=equal_rel_types) +        q = q.exclude(left_record_id__in=exclude) +        equals += list(q) +        exclude += equals +        for eq_id in equals: +            equals += cls._get_equals(eq_id, equal_rel_types, exclude=exclude) +        return equals      @classmethod      def _update_equals(cls, item_id, equals): @@ -1485,24 +1494,32 @@ class ContextRecordTree(RelationsViews):          equal_rel_types = cls._get_base_equal_relations()          keys = []          for child_id, parent_id in relations: -            equals = set(cls._get_equals(child_id, equal_rel_types)) -            keys += cls._update_equals(child_id, equals) -            for alt_child in equals: +            equals_child = set(cls._get_equals(child_id, equal_rel_types)) +            keys += cls._update_equals(child_id, equals_child) +            for alt_child in equals_child:                  if alt_child != child_id:                      cls.objects.get_or_create(                          key=f"{alt_child}_{parent_id}",                          cr_id=alt_child, cr_parent_id=parent_id                      )                      keys.append((alt_child, parent_id)) -            equals = set(cls._get_equals(parent_id, equal_rel_types)) -            keys += cls._update_equals(parent_id, equals) -            for alt_parent in equals: +            equals_parent = set(cls._get_equals(parent_id, equal_rel_types)) +            keys += cls._update_equals(parent_id, equals_parent) +            for alt_parent in equals_parent:                  if alt_parent != parent_id:                      cls.objects.get_or_create(                          key=f"{child_id}_{alt_parent}",                          cr_id=child_id, cr_parent_id=alt_parent                      )                      keys.append((child_id, alt_parent)) + +                    for alt_child in equals_child: +                        if alt_child != child_id: +                            cls.objects.get_or_create( +                                key=f"{alt_child}_{alt_parent}", +                                cr_id=alt_child, cr_parent_id=alt_parent +                            ) +                            keys.append((alt_child, alt_parent))          return set(keys)      @classmethod @@ -1541,7 +1558,6 @@ class ContextRecordTree(RelationsViews):              equals = set(cls._get_equals(item_id, equal_rel_types))              all_relations.update(cls._update_equals(item_id, equals)) -          ## delete old relations          if not already_updated:              already_updated = [item_id] diff --git a/archaeological_context_records/tests.py b/archaeological_context_records/tests.py index 553731542..6b3e9f648 100644 --- a/archaeological_context_records/tests.py +++ b/archaeological_context_records/tests.py @@ -884,7 +884,7 @@ class RecordRelationsTest(ContextRecordInit, TestCase):      model = models.ContextRecord      def setUp(self): -        for idx in range(1, 11): +        for idx in range(1, 15):              self.create_context_record({"label": f"CR {idx}"})      def test_relations(self): @@ -949,6 +949,7 @@ class RecordRelationsTest(ContextRecordInit, TestCase):          profile = get_current_profile(force=True)          models.ContextRecordTree.check_engine()          crs = self.context_records +          rel_type_1 = models.RelationType.objects.create(              symmetrical=False, txt_idx="rel_1",              logical_relation='included' @@ -978,6 +979,19 @@ class RecordRelationsTest(ContextRecordInit, TestCase):                  right_record=crs[parent_idx - 1],                  relation_type=rel_type_1              ) +        rel_type_2 = models.RelationType.objects.create( +            symmetrical=True, txt_idx="rel_2", +            logical_relation='equal' +        ) +        equal_relations = ( +            (10, 11), (11, 12), (5, 13), (3, 14) +        ) +        for child_idx, parent_idx in equal_relations: +            models.RecordRelations.objects.create( +                left_record=crs[child_idx - 1], +                right_record=crs[parent_idx - 1], +                relation_type=rel_type_2 +            )          q = models.ContextRecordTree.objects.filter(              cr_parent_id=crs[2].pk, cr_id=crs[0].pk)          self.assertGreaterEqual(q.count(), 1) @@ -992,6 +1006,11 @@ class RecordRelationsTest(ContextRecordInit, TestCase):          # verify tree generation          full_trees = [              [10, 5, 3, 2], +            [11, 5, 3, 2], +            [12, 5, 3, 2], +            [12, 13, 14, 2], +            [10, 5, 14, 2], +            [10, 14],              [10, 5, 3, 1],              [9, 5, 3, 2],              [9, 5, 3, 1], @@ -1004,53 +1023,6 @@ class RecordRelationsTest(ContextRecordInit, TestCase):          ]          self._test_tree_generation(0, full_trees)          trees = [ -            [10, 5, 3, 2], -            [10, 5, 3, 1], -            [9, 5, 3, 2], -            [9, 5, 3, 1], -            [8, 4, 3, 2], -            [8, 4, 3, 1], -            [7, 4, 3, 2], -            [7, 4, 3, 1], -            [6, 4, 3, 2], -            [6, 4, 3, 1], -        ] -        self._test_tree_generation(1, trees) -        trees = [ -            [10, 5, 3, 2], -            [10, 5, 3, 1], -            [9, 5, 3, 2], -            [9, 5, 3, 1], -            [8, 4, 3, 2], -            [8, 4, 3, 1], -            [7, 4, 3, 2], -            [7, 4, 3, 1], -            [6, 4, 3, 2], -            [6, 4, 3, 1], -        ] -        self._test_tree_generation(2, trees) -        trees = [ -            [8, 4, 3, 2], -            [8, 4, 3, 1], -            [7, 4, 3, 2], -            [7, 4, 3, 1], -            [6, 4, 3, 2], -            [6, 4, 3, 1], -        ] -        self._test_tree_generation(3, trees) -        trees = [ -            [10, 5, 3, 2], -            [10, 5, 3, 1], -            [9, 5, 3, 2], -            [9, 5, 3, 1], -        ] -        self._test_tree_generation(4, trees) -        trees = [ -            [6, 4, 3, 2], -            [6, 4, 3, 1], -        ] -        self._test_tree_generation(5, trees) -        trees = [              [7, 4, 3, 2],              [7, 4, 3, 1],          ] @@ -1075,8 +1047,33 @@ class RecordRelationsTest(ContextRecordInit, TestCase):          models.ContextRecordTree.regenerate_all()          self._test_tree_(full_trees, "'FULL GENERATION'") -        # test remove a Node -        # test EQUIV +        # test remove a node +        nb = models.ContextRecordTree.objects.filter( +            cr_parent=crs[6 - 1], cr=crs[3 - 1]).count() +        self.assertEqual(nb, 1) +        models.RecordRelations.objects.filter( +            left_record=crs[3 - 1], +            right_record=crs[4 - 1] +        ).delete() +        models.ContextRecordTree.update(crs[3 - 1].pk) +        models.ContextRecordTree.update(crs[4 - 1].pk) +        nb = models.ContextRecordTree.objects.filter( +            cr_parent=crs[6 - 1], cr=crs[3 - 1]).count() +        self.assertEqual(nb, 0) + +        # test remove a node (update equal links) +        nb = models.ContextRecordTree.objects.filter( +            cr_parent=crs[10 - 1], cr=crs[14 - 1]).count() +        self.assertEqual(nb, 1) +        models.RecordRelations.objects.filter( +            left_record=crs[3 - 1], +            right_record=crs[5 - 1] +        ).delete() +        models.ContextRecordTree.update(crs[3 - 1].pk) +        models.ContextRecordTree.update(crs[5 - 1].pk) +        nb = models.ContextRecordTree.objects.filter( +            cr_parent=crs[10 - 1], cr=crs[14 - 1]).count() +        self.assertEqual(nb, 0)      def _test_tree_(self, test_trees, context_record):          crs = self.context_records  | 
