summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ishtar_common/serializers.py81
-rw-r--r--ishtar_common/tests.py67
2 files changed, 102 insertions, 46 deletions
diff --git a/ishtar_common/serializers.py b/ishtar_common/serializers.py
index 4867b58f6..6d25efdb6 100644
--- a/ishtar_common/serializers.py
+++ b/ishtar_common/serializers.py
@@ -103,14 +103,12 @@ def archive_serialization(result, archive_dir=None, archive=False,
return archive_name
-def type_serialization(archive=False, return_empty_types=False,
- archive_name=None):
+def generic_get_results(model_list, dirname, no_geo=True):
result = OrderedDict()
- for model in apps.get_models():
- if not isinstance(model(), models.GeneralType):
- continue
+ for model in model_list:
model_name = model.__name__
model_name = str(model.__module__).split(".")[0] + "__" + model_name
+
base_q = model.objects
q = base_q
recursion = None
@@ -118,16 +116,25 @@ def type_serialization(archive=False, return_empty_types=False,
recursion = "parent"
elif hasattr(model, "inverse_relation"):
recursion = "inverse_relation"
+ elif hasattr(model, "children"):
+ recursion = "children__id"
if recursion:
q = q.filter(**{recursion + "__isnull": True})
- key = ("types", model_name)
+
+ key = (dirname, model_name)
result[key] = serialize(
- "json", q.all(), indent=2,
- use_natural_foreign_keys=True, use_natural_primary_keys=True
+ "json", q.all(),
+ indent=2,
+ use_natural_foreign_keys=True, use_natural_primary_keys=True,
)
+
if recursion:
serialized = [item["id"] for item in q.values("id").all()]
- q = base_q.filter(**{recursion + "_id__in": serialized}
+ recursion_in = recursion
+ if not recursion.endswith("_id"):
+ recursion_in += "_id"
+ recursion_in += "__in"
+ q = base_q.filter(**{recursion_in: serialized}
).exclude(id__in=serialized)
while q.count():
v = serialize(
@@ -137,7 +144,7 @@ def type_serialization(archive=False, return_empty_types=False,
new_result += json.loads(v)
result[key] = json.dumps(new_result, indent=2)
serialized += [item["id"] for item in q.values("id").all()]
- q = base_q.filter(**{recursion + "_id__in": serialized}
+ q = base_q.filter(**{recursion_in: serialized}
).exclude(id__in=serialized)
# managed circular
q = base_q.exclude(id__in=serialized)
@@ -153,27 +160,20 @@ def type_serialization(archive=False, return_empty_types=False,
new_result += result_cleaned
new_result += result_to_add
result[key] = json.dumps(new_result, indent=2)
- return archive_serialization(result, archive_dir="types", archive=archive,
- return_empty_types=return_empty_types,
- archive_name=archive_name)
-
-def generic_get_results(model_list, dirname):
- result = OrderedDict()
- for model in model_list:
- model_name = model.__name__
- model_name = str(model.__module__).split(".")[0] + "__" + model_name
- key = (dirname, model_name)
- result[key] = serialize(
- "json", model.objects.all(),
- indent=2,
- use_natural_foreign_keys=True, use_natural_primary_keys=True,
- )
+ excluded_fields = []
if hasattr(model, "SERIALIZATION_EXCLUDE"):
+ excluded_fields = list(model.SERIALIZATION_EXCLUDE)
+ if no_geo:
+ excluded_fields += ["center", "limit"] + [
+ field.name for field in models.GeoItem._meta.get_fields()
+ ]
+ if excluded_fields:
new_result = json.loads(result[key])
for idx in range(len(new_result)):
- for excluded_field in model.SERIALIZATION_EXCLUDE:
- new_result[idx]["fields"].pop(excluded_field)
+ for excluded_field in excluded_fields:
+ if excluded_field in new_result[idx]["fields"]:
+ new_result[idx]["fields"].pop(excluded_field)
result[key] = json.dumps(new_result, indent=2)
return result
@@ -201,6 +201,18 @@ def generic_archive_files(model_list, archive_name=None):
return archive_name
+def type_serialization(archive=False, return_empty_types=False,
+ archive_name=None):
+ TYPE_MODEL_LIST = [
+ model for model in apps.get_models()
+ if isinstance(model(), models.GeneralType)
+ ]
+ result = generic_get_results(TYPE_MODEL_LIST, "types")
+ return archive_serialization(result, archive_dir="types", archive=archive,
+ return_empty_types=return_empty_types,
+ archive_name=archive_name)
+
+
CONF_MODEL_LIST = [
models.IshtarSiteProfile, models.GlobalVar, models.CustomForm,
models.ExcludedField, models.JsonDataSection, models.JsonDataField,
@@ -243,6 +255,20 @@ def importer_serialization(archive=False, return_empty_types=False,
return full_archive
+GEO_MODEL_LIST = [
+ models.State, models.Department, models.Town
+]
+
+
+def geo_serialization(archive=False, return_empty_types=False,
+ archive_name=None, no_geo=True):
+ result = generic_get_results(GEO_MODEL_LIST, "common_geo", no_geo=no_geo)
+ full_archive = archive_serialization(
+ result, archive_dir="common_geo", archive=archive,
+ return_empty_types=return_empty_types, archive_name=archive_name)
+ return full_archive
+
+
def restore_serialized(archive_name, delete_existing=False):
with zipfile.ZipFile(archive_name, "r") as zip_file:
# check version
@@ -256,6 +282,7 @@ def restore_serialized(archive_name, delete_existing=False):
DIRS = (
("types", [None]), ("common_configuration", CONF_MODEL_LIST),
("common_imports", IMPORT_MODEL_LIST),
+ ("common_geo", GEO_MODEL_LIST)
)
namelist = zip_file.namelist()
for current_dir, model_list in DIRS:
diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py
index 7be176c05..efa0188ff 100644
--- a/ishtar_common/tests.py
+++ b/ishtar_common/tests.py
@@ -53,7 +53,7 @@ from ishtar_common.apps import admin_site
from ishtar_common.serializers import type_serialization, \
SERIALIZATION_VERSION, serialization_info, \
restore_serialized, conf_serialization, CONF_MODEL_LIST, \
- importer_serialization, IMPORT_MODEL_LIST
+ importer_serialization, IMPORT_MODEL_LIST, geo_serialization, GEO_MODEL_LIST
from ishtar_common.utils import post_save_geo, update_data, move_dict_data, \
rename_and_simplify_media_name, try_fix_file
@@ -661,6 +661,26 @@ class SerializationTest(TestCase):
self.create_default_importer()
self.generic_serialization_test(importer_serialization)
+ def create_geo_default(self):
+ s = models.State.objects.create(label="test", number="999")
+ d = models.Department.objects.create(label="test", number="999",
+ state=s)
+ t1 = models.Town.objects.create(
+ name="Test town",
+ center="SRID=4326;POINT(-44.3 60.1)",
+ numero_insee="12345", departement=d
+ )
+ t2 = models.Town.objects.create(
+ name="Test town 2",
+ center="SRID=4326;POINT(-44.2 60.2)",
+ numero_insee="12346", departement=d
+ )
+ t2.children.add(t1)
+
+ def test_geo_serialization(self):
+ self.create_geo_default()
+ self.generic_serialization_test(geo_serialization)
+
def test_serialization_zip(self):
zip_filename = type_serialization(archive=True)
# only check the validity of the zip, the type content is tested above
@@ -723,40 +743,49 @@ class SerializationTest(TestCase):
self.assertTrue(OperationRT.objects.filter(
inverse_relation__isnull=False).count())
- def test_conf_restore(self):
- values = self.create_default_conf()
+ def generic_restore_test_genzip(self, model_list, serialization):
current_number = {}
- for model in CONF_MODEL_LIST:
+ for model in model_list:
current_number[model.__name__] = model.objects.count()
- zip_filename = conf_serialization(archive=True)
- os.remove(values["document_template"].template.path)
+ zip_filename = serialization(archive=True)
+ return current_number, zip_filename
+ def generic_restore_test(self, zip_filename, current_number, model_list):
restore_serialized(zip_filename, delete_existing=True)
- for model in CONF_MODEL_LIST:
+ for model in model_list:
previous_nb = current_number[model.__name__]
current_nb = model.objects.count()
self.assertEqual(
previous_nb, current_nb,
msg="Restore for model {} failed. Initial: {}, restored: "
"{}.".format(model.__name__, previous_nb, current_nb))
+
+ def test_conf_restore(self):
+ values = self.create_default_conf()
+ current_number, zip_filename = self.generic_restore_test_genzip(
+ CONF_MODEL_LIST, conf_serialization)
+ os.remove(values["document_template"].template.path)
+ self.generic_restore_test(zip_filename, current_number, CONF_MODEL_LIST)
self.assertTrue(
os.path.isfile(values["document_template"].template.path)
)
def test_importer_restore(self):
self.create_default_importer()
- current_number = {}
- for model in IMPORT_MODEL_LIST:
- current_number[model.__name__] = model.objects.count()
- zip_filename = importer_serialization(archive=True)
- restore_serialized(zip_filename, delete_existing=True)
- for model in IMPORT_MODEL_LIST:
- previous_nb = current_number[model.__name__]
- current_nb = model.objects.count()
- self.assertEqual(
- previous_nb, current_nb,
- msg="Restore for model {} failed. Initial: {}, restored: "
- "{}.".format(model.__name__, previous_nb, current_nb))
+ current_number, zip_filename = self.generic_restore_test_genzip(
+ IMPORT_MODEL_LIST, importer_serialization)
+ self.generic_restore_test(zip_filename, current_number,
+ IMPORT_MODEL_LIST)
+
+ def test_geo_restore(self):
+ self.create_geo_default()
+ self.assertTrue(models.Town.objects.get(numero_insee="12345").center)
+ current_number, zip_filename = self.generic_restore_test_genzip(
+ GEO_MODEL_LIST, geo_serialization)
+ self.generic_restore_test(zip_filename, current_number,
+ GEO_MODEL_LIST)
+ # no geo restore
+ self.assertFalse(models.Town.objects.get(numero_insee="12345").center)
class AccessControlTest(TestCase):