summaryrefslogtreecommitdiff
path: root/ishtar_common
diff options
context:
space:
mode:
Diffstat (limited to 'ishtar_common')
-rw-r--r--ishtar_common/serializers.py37
-rw-r--r--ishtar_common/tests.py30
2 files changed, 52 insertions, 15 deletions
diff --git a/ishtar_common/serializers.py b/ishtar_common/serializers.py
index f71109374..6f3cf095d 100644
--- a/ishtar_common/serializers.py
+++ b/ishtar_common/serializers.py
@@ -1,7 +1,9 @@
import datetime
import json
+import importlib
import os
import tempfile
+import zipfile
from rest_framework import serializers
from zipfile import ZipFile
@@ -21,15 +23,22 @@ class PublicSerializer(serializers.BaseSerializer):
SERIALIZATION_VERSION = "1.0"
+def get_model_from_filename(filename):
+ filename = filename.split(".")[0] # remove extension
+ module_name, model_name = filename.split("__")
+ module = importlib.import_module(module_name + ".models")
+ return getattr(module, model_name)
+
+
def serialization_info():
site = Site.objects.get_current()
- return json.dumps({
+ return {
"serialize-version": SERIALIZATION_VERSION,
"ishtar-version": get_version(),
"domain": site.domain,
"name": site.name,
"date": datetime.datetime.now().isoformat()
- }, indent=2)
+ }
def type_serialization(archive=False, return_empty_types=False,
@@ -76,7 +85,9 @@ def type_serialization(archive=False, return_empty_types=False,
base_filename = "info.json"
filename = tmpdirname + os.sep + base_filename
with open(filename, "w") as json_file:
- json_file.write(serialization_info())
+ json_file.write(
+ json.dumps(serialization_info(), indent=2)
+ )
current_zip.write(filename, arcname=base_filename)
for model_name in result:
@@ -89,5 +100,21 @@ def type_serialization(archive=False, return_empty_types=False,
return archive_name
-def restore_serialized(archive_name):
- pass
+def restore_serialized(archive_name, delete_existing=False):
+ with zipfile.ZipFile(archive_name, "r") as zip_file:
+ # check version
+ info = json.loads(zip_file.read("info.json").decode("utf-8"))
+ if info["serialize-version"] != SERIALIZATION_VERSION:
+ raise ValueError(
+ "This dump version: {} is not managed by this Ishtar "
+ "installation".format(info["serialize-version"])
+ )
+
+ # restore types
+ for json_filename in zip_file.namelist():
+ path = json_filename.split(os.sep)
+ if len(path) != 2 or path[0] != "types":
+ continue
+ model = get_model_from_filename(path[-1])
+ if delete_existing:
+ model.objects.all().delete()
diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py
index 25be97692..03c2ad03a 100644
--- a/ishtar_common/tests.py
+++ b/ishtar_common/tests.py
@@ -25,8 +25,9 @@ import io
import json
import os
import shutil
-from io import StringIO
+import tempfile
import zipfile
+from io import StringIO
from django.apps import apps
@@ -49,7 +50,8 @@ from ishtar_common import models
from ishtar_common import views
from ishtar_common.apps import admin_site
from ishtar_common.serializers import type_serialization, \
- SERIALIZATION_VERSION
+ SERIALIZATION_VERSION, get_model_from_filename, serialization_info, \
+ restore_serialized
from ishtar_common.utils import post_save_geo, update_data, move_dict_data, \
rename_and_simplify_media_name, try_fix_file
@@ -613,7 +615,21 @@ class SerializationTest(TestCase):
self.assertIsNone(zip_file.testzip())
info = json.loads(zip_file.read("info.json").decode("utf-8"))
self.assertEqual(info["serialize-version"], SERIALIZATION_VERSION)
- print(info)
+
+ def test_restore_version(self):
+ zip_filename = type_serialization(archive=True)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ with zipfile.ZipFile(zip_filename, "w") as zip_file:
+ base_filename = "info.json"
+ filename = tmpdirname + os.sep + base_filename
+ with open(filename, "w") as json_file:
+ info = serialization_info()
+ info["serialize-version"] = "-42"
+ json_file.write(json.dumps(info, indent=2))
+
+ zip_file.write(filename, arcname=base_filename)
+ with self.assertRaises(ValueError):
+ restore_serialized(zip_filename)
def test_type_restore(self):
zip_filename = type_serialization(archive=True)
@@ -623,15 +639,9 @@ class SerializationTest(TestCase):
path = json_filename.split(os.sep)
if len(path) != 2 or path[0] != "types":
continue
- filename = path[-1].split(".")[0]
- module_name, model_name = filename.split("__")
- module = importlib.import_module(module_name + ".models")
- model = getattr(module, model_name)
+ model = get_model_from_filename(path[-1])
initial_count[json_filename] = model.objects.count()
model.objects.all().delete()
- print(initial_count)
-
-
class AccessControlTest(TestCase):