diff options
| -rw-r--r-- | ishtar_common/serializers.py | 96 | ||||
| -rw-r--r-- | ishtar_common/tests.py | 24 | 
2 files changed, 84 insertions, 36 deletions
diff --git a/ishtar_common/serializers.py b/ishtar_common/serializers.py index 2069c7e38..187686321 100644 --- a/ishtar_common/serializers.py +++ b/ishtar_common/serializers.py @@ -1,3 +1,4 @@ +from copy import deepcopy  import datetime  import json  import importlib @@ -41,21 +42,62 @@ def serialization_info():      } -def type_serialization(archive=False, return_empty_types=False, -                       archive_name=None): +def archive_serialization(result, archive_dir=None, archive=False, +                          return_empty_types=False, archive_name=None):      """      Serialize all types models to JSON      Used for import and export scripts +    :param result: serialization results +    :param archive_dir: directory inside the archive (default None)      :param return_empty_types: if True instead of serialization return empty      types (default False)      :param archive: if True return a zip file containing all the file serialized -    (defaukt False) +    (default False)      :return: string containing the json serialization of types unless      return_empty_types or archive is set to True      """      if archive and return_empty_types:          raise ValueError("archive and return_empty_types are incompatible") +    if return_empty_types: +        return [k for k in result if not result[k]] +    if not archive: +        return result +    archive_created = False +    if not archive_name: +        archive_created = True +        tmpdir = tempfile.mkdtemp(prefix="ishtarexport-") + os.sep +        archive_name = tmpdir + "ishtar-{}.zip".format( +            datetime.date.today().strftime("%Y-%m-%d") +        ) +    if not archive_name.endswith(".zip"): +        archive_name += ".zip" +    with tempfile.TemporaryDirectory() as tmpdirname: +        if archive_dir: +            os.mkdir(tmpdirname + os.sep + archive_dir) + +        with ZipFile(archive_name, 'w') as current_zip: +            if archive_created: +                base_filename = "info.json" +                filename = tmpdirname + os.sep + base_filename +                with open(filename, "w") as json_file: +                    json_file.write( +                        json.dumps(serialization_info(), indent=2) +                    ) +                current_zip.write(filename, arcname=base_filename) + +            for model_name in result: +                base_filename = model_name + ".json" +                filename = tmpdirname + os.sep + base_filename +                with open(filename, "w") as json_file: +                    json_file.write(result[model_name]) +                current_zip.write(filename, +                                  arcname="types" + os.sep + base_filename) +    return archive_name + + +def type_serialization(archive=False, return_empty_types=False, +                       archive_name=None):      result = {}      for model in apps.get_models():          if not isinstance(model(), models.GeneralType): @@ -90,37 +132,23 @@ def type_serialization(archive=False, return_empty_types=False,                  serialized += [item["id"] for item in q.values("id").all()]                  q = base_q.filter(**{recursion + "_id__in": serialized}                                    ).exclude(id__in=serialized) -    if return_empty_types: -        return [k for k in result if not result[k]] -    if not archive: -        return result -    if not archive_name: -        tmpdir = tempfile.mkdtemp(prefix="ishtarexport-") + os.sep -        archive_name = tmpdir + "ishtar-{}.zip".format( -            datetime.date.today().strftime("%Y-%m-%d") -        ) -    if not archive_name.endswith(".zip"): -        archive_name += ".zip" -    with tempfile.TemporaryDirectory() as tmpdirname: -        os.mkdir(tmpdirname + os.sep + "types") - -        with ZipFile(archive_name, 'w') as current_zip: -            base_filename = "info.json" -            filename = tmpdirname + os.sep + base_filename -            with open(filename, "w") as json_file: -                json_file.write( -                    json.dumps(serialization_info(), indent=2) -                ) -            current_zip.write(filename, arcname=base_filename) - -            for model_name in result: -                base_filename = model_name + ".json" -                filename = tmpdirname + os.sep + base_filename -                with open(filename, "w") as json_file: -                    json_file.write(result[model_name]) -                current_zip.write(filename, -                                  arcname="types" + os.sep + base_filename) -    return archive_name +            # managed circular +            q = base_q.exclude(id__in=serialized) +            if q.count(): +                v = serialize( +                    "json", q.all(), indent=2, use_natural_foreign_keys=True, +                    use_natural_primary_keys=True) +                result_to_add = json.loads(v) +                result_cleaned = deepcopy(result_to_add) +                for res in result_cleaned:  # first add with no recursion +                    res["fields"][recursion] = None +                new_result = json.loads(result[model_name]) +                new_result += result_cleaned +                new_result += result_to_add +                result[model_name] = json.dumps(new_result, indent=2) +    return archive_serialization(result, archive_dir="types", archive=archive, +                                 return_empty_types=return_empty_types, +                                 archive_name=archive_name)  def restore_serialized(archive_name, delete_existing=False): diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index e292ae097..7e1948bf0 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -604,8 +604,15 @@ class SerializationTest(TestCase):              module_name, model_name = k.split("__")              module = importlib.import_module(module_name + ".models")              model = getattr(module, model_name) -            self.assertEqual(model.objects.count(), -                             len(json.loads(json_result[k]))) +            current_count = model.objects.count() +            serialization_count = len(json.loads(json_result[k])) +            # has to be at least equal (can be superior for model with +            # recursivity) +            self.assertTrue( +                serialization_count >= current_count, +                msg="Serialization for model {}.{} failed. {} serialized {} " +                    "expected".format(module.__name__, model_name, +                                      serialization_count, current_count))      def test_serialization_zip(self):          zip_filename = type_serialization(archive=True) @@ -632,6 +639,11 @@ class SerializationTest(TestCase):              restore_serialized(zip_filename)      def test_type_restore(self): +        from archaeological_context_records.models import RelationType as CRRT +        from archaeological_operations.models import RelationType as OperationRT +        cr_rel_type_nb = CRRT.objects.count() +        ope_rel_type_nb = OperationRT.objects.count() +          models.AuthorType.objects.create(label="Test", txt_idx="test")          zip_filename = type_serialization(archive=True) @@ -644,6 +656,10 @@ class SerializationTest(TestCase):          self.assertEqual(              models.AuthorType.objects.filter(txt_idx="am-i-still-here").count(),              1) +        self.assertEqual(cr_rel_type_nb, CRRT.objects.count()) +        self.assertEqual(ope_rel_type_nb, OperationRT.objects.count()) +        self.assertTrue(OperationRT.objects.filter( +            inverse_relation__isnull=False).count())          models.AuthorType.objects.filter(txt_idx="am-i-still-here").delete()          zip_filename = type_serialization(archive=True) @@ -656,6 +672,10 @@ class SerializationTest(TestCase):          self.assertEqual(              models.AuthorType.objects.filter(txt_idx="am-i-still-here").count(),              0) +        self.assertEqual(cr_rel_type_nb, CRRT.objects.count()) +        self.assertEqual(ope_rel_type_nb, OperationRT.objects.count()) +        self.assertTrue(OperationRT.objects.filter( +            inverse_relation__isnull=False).count())  class AccessControlTest(TestCase):  | 
