diff options
Diffstat (limited to 'ishtar_common')
| -rw-r--r-- | ishtar_common/serializers.py | 81 | ||||
| -rw-r--r-- | ishtar_common/tests.py | 67 | 
2 files changed, 102 insertions, 46 deletions
diff --git a/ishtar_common/serializers.py b/ishtar_common/serializers.py index 4867b58f6..6d25efdb6 100644 --- a/ishtar_common/serializers.py +++ b/ishtar_common/serializers.py @@ -103,14 +103,12 @@ def archive_serialization(result, archive_dir=None, archive=False,      return archive_name -def type_serialization(archive=False, return_empty_types=False, -                       archive_name=None): +def generic_get_results(model_list, dirname, no_geo=True):      result = OrderedDict() -    for model in apps.get_models(): -        if not isinstance(model(), models.GeneralType): -            continue +    for model in model_list:          model_name = model.__name__          model_name = str(model.__module__).split(".")[0] + "__" + model_name +          base_q = model.objects          q = base_q          recursion = None @@ -118,16 +116,25 @@ def type_serialization(archive=False, return_empty_types=False,              recursion = "parent"          elif hasattr(model, "inverse_relation"):              recursion = "inverse_relation" +        elif hasattr(model, "children"): +            recursion = "children__id"          if recursion:              q = q.filter(**{recursion + "__isnull": True}) -        key = ("types", model_name) + +        key = (dirname, model_name)          result[key] = serialize( -            "json", q.all(), indent=2, -            use_natural_foreign_keys=True, use_natural_primary_keys=True +            "json", q.all(), +            indent=2, +            use_natural_foreign_keys=True, use_natural_primary_keys=True,          ) +          if recursion:              serialized = [item["id"] for item in q.values("id").all()] -            q = base_q.filter(**{recursion + "_id__in": serialized} +            recursion_in = recursion +            if not recursion.endswith("_id"): +                recursion_in += "_id" +            recursion_in += "__in" +            q = base_q.filter(**{recursion_in: serialized}                                ).exclude(id__in=serialized)              while q.count():                  v = serialize( @@ -137,7 +144,7 @@ def type_serialization(archive=False, return_empty_types=False,                  new_result += json.loads(v)                  result[key] = json.dumps(new_result, indent=2)                  serialized += [item["id"] for item in q.values("id").all()] -                q = base_q.filter(**{recursion + "_id__in": serialized} +                q = base_q.filter(**{recursion_in: serialized}                                    ).exclude(id__in=serialized)              # managed circular              q = base_q.exclude(id__in=serialized) @@ -153,27 +160,20 @@ def type_serialization(archive=False, return_empty_types=False,                  new_result += result_cleaned                  new_result += result_to_add                  result[key] = json.dumps(new_result, indent=2) -    return archive_serialization(result, archive_dir="types", archive=archive, -                                 return_empty_types=return_empty_types, -                                 archive_name=archive_name) - -def generic_get_results(model_list, dirname): -    result = OrderedDict() -    for model in model_list: -        model_name = model.__name__ -        model_name = str(model.__module__).split(".")[0] + "__" + model_name -        key = (dirname, model_name) -        result[key] = serialize( -            "json", model.objects.all(), -            indent=2, -            use_natural_foreign_keys=True, use_natural_primary_keys=True, -        ) +        excluded_fields = []          if hasattr(model, "SERIALIZATION_EXCLUDE"): +            excluded_fields = list(model.SERIALIZATION_EXCLUDE) +        if no_geo: +            excluded_fields += ["center", "limit"] + [ +                field.name for field in models.GeoItem._meta.get_fields() +            ] +        if excluded_fields:              new_result = json.loads(result[key])              for idx in range(len(new_result)): -                for excluded_field in model.SERIALIZATION_EXCLUDE: -                    new_result[idx]["fields"].pop(excluded_field) +                for excluded_field in excluded_fields: +                    if excluded_field in new_result[idx]["fields"]: +                        new_result[idx]["fields"].pop(excluded_field)              result[key] = json.dumps(new_result, indent=2)      return result @@ -201,6 +201,18 @@ def generic_archive_files(model_list, archive_name=None):      return archive_name +def type_serialization(archive=False, return_empty_types=False, +                       archive_name=None): +    TYPE_MODEL_LIST = [ +        model for model in apps.get_models() +        if isinstance(model(), models.GeneralType) +    ] +    result = generic_get_results(TYPE_MODEL_LIST, "types") +    return archive_serialization(result, archive_dir="types", archive=archive, +                                 return_empty_types=return_empty_types, +                                 archive_name=archive_name) + +  CONF_MODEL_LIST = [      models.IshtarSiteProfile, models.GlobalVar, models.CustomForm,      models.ExcludedField, models.JsonDataSection, models.JsonDataField, @@ -243,6 +255,20 @@ def importer_serialization(archive=False, return_empty_types=False,      return full_archive +GEO_MODEL_LIST = [ +    models.State, models.Department, models.Town +] + + +def geo_serialization(archive=False, return_empty_types=False, +                      archive_name=None, no_geo=True): +    result = generic_get_results(GEO_MODEL_LIST, "common_geo", no_geo=no_geo) +    full_archive = archive_serialization( +        result, archive_dir="common_geo", archive=archive, +        return_empty_types=return_empty_types, archive_name=archive_name) +    return full_archive + +  def restore_serialized(archive_name, delete_existing=False):      with zipfile.ZipFile(archive_name, "r") as zip_file:          # check version @@ -256,6 +282,7 @@ def restore_serialized(archive_name, delete_existing=False):          DIRS = (              ("types", [None]), ("common_configuration", CONF_MODEL_LIST),              ("common_imports", IMPORT_MODEL_LIST), +            ("common_geo", GEO_MODEL_LIST)          )          namelist = zip_file.namelist()          for current_dir, model_list in DIRS: diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index 7be176c05..efa0188ff 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -53,7 +53,7 @@ from ishtar_common.apps import admin_site  from ishtar_common.serializers import type_serialization, \      SERIALIZATION_VERSION, serialization_info, \      restore_serialized, conf_serialization, CONF_MODEL_LIST, \ -    importer_serialization, IMPORT_MODEL_LIST +    importer_serialization, IMPORT_MODEL_LIST, geo_serialization, GEO_MODEL_LIST  from ishtar_common.utils import post_save_geo, update_data, move_dict_data, \      rename_and_simplify_media_name, try_fix_file @@ -661,6 +661,26 @@ class SerializationTest(TestCase):          self.create_default_importer()          self.generic_serialization_test(importer_serialization) +    def create_geo_default(self): +        s = models.State.objects.create(label="test", number="999") +        d = models.Department.objects.create(label="test", number="999", +                                             state=s) +        t1 = models.Town.objects.create( +            name="Test town", +            center="SRID=4326;POINT(-44.3 60.1)", +            numero_insee="12345", departement=d +        ) +        t2 = models.Town.objects.create( +            name="Test town 2", +            center="SRID=4326;POINT(-44.2 60.2)", +            numero_insee="12346", departement=d +        ) +        t2.children.add(t1) + +    def test_geo_serialization(self): +        self.create_geo_default() +        self.generic_serialization_test(geo_serialization) +      def test_serialization_zip(self):          zip_filename = type_serialization(archive=True)          # only check the validity of the zip, the type content is tested above @@ -723,40 +743,49 @@ class SerializationTest(TestCase):          self.assertTrue(OperationRT.objects.filter(              inverse_relation__isnull=False).count()) -    def test_conf_restore(self): -        values = self.create_default_conf() +    def generic_restore_test_genzip(self, model_list, serialization):          current_number = {} -        for model in CONF_MODEL_LIST: +        for model in model_list:              current_number[model.__name__] = model.objects.count() -        zip_filename = conf_serialization(archive=True) -        os.remove(values["document_template"].template.path) +        zip_filename = serialization(archive=True) +        return current_number, zip_filename +    def generic_restore_test(self, zip_filename, current_number, model_list):          restore_serialized(zip_filename, delete_existing=True) -        for model in CONF_MODEL_LIST: +        for model in model_list:              previous_nb = current_number[model.__name__]              current_nb = model.objects.count()              self.assertEqual(                  previous_nb, current_nb,                  msg="Restore for model {} failed. Initial: {}, restored: "                      "{}.".format(model.__name__, previous_nb, current_nb)) + +    def test_conf_restore(self): +        values = self.create_default_conf() +        current_number, zip_filename = self.generic_restore_test_genzip( +            CONF_MODEL_LIST, conf_serialization) +        os.remove(values["document_template"].template.path) +        self.generic_restore_test(zip_filename, current_number, CONF_MODEL_LIST)          self.assertTrue(              os.path.isfile(values["document_template"].template.path)          )      def test_importer_restore(self):          self.create_default_importer() -        current_number = {} -        for model in IMPORT_MODEL_LIST: -            current_number[model.__name__] = model.objects.count() -        zip_filename = importer_serialization(archive=True) -        restore_serialized(zip_filename, delete_existing=True) -        for model in IMPORT_MODEL_LIST: -            previous_nb = current_number[model.__name__] -            current_nb = model.objects.count() -            self.assertEqual( -                previous_nb, current_nb, -                msg="Restore for model {} failed. Initial: {}, restored: " -                    "{}.".format(model.__name__, previous_nb, current_nb)) +        current_number, zip_filename = self.generic_restore_test_genzip( +            IMPORT_MODEL_LIST, importer_serialization) +        self.generic_restore_test(zip_filename, current_number, +                                  IMPORT_MODEL_LIST) + +    def test_geo_restore(self): +        self.create_geo_default() +        self.assertTrue(models.Town.objects.get(numero_insee="12345").center) +        current_number, zip_filename = self.generic_restore_test_genzip( +            GEO_MODEL_LIST, geo_serialization) +        self.generic_restore_test(zip_filename, current_number, +                                  GEO_MODEL_LIST) +        # no geo restore +        self.assertFalse(models.Town.objects.get(numero_insee="12345").center)  class AccessControlTest(TestCase):  | 
