diff options
| -rw-r--r-- | archaeological_operations/tests.py | 118 | ||||
| -rw-r--r-- | ishtar_common/data_importer.py | 93 | ||||
| -rw-r--r-- | ishtar_common/models.py | 16 | ||||
| -rw-r--r-- | ishtar_common/tests.py | 4 | 
4 files changed, 187 insertions, 44 deletions
| diff --git a/archaeological_operations/tests.py b/archaeological_operations/tests.py index d4134693f..f2126a68e 100644 --- a/archaeological_operations/tests.py +++ b/archaeological_operations/tests.py @@ -35,7 +35,7 @@ import models  from archaeological_operations import views  from ishtar_common.models import OrganizationType, Organization, \ -    ImporterType, IshtarUser, TargetKey, IshtarSiteProfile, Town +    ImporterType, IshtarUser, TargetKey, ImporterModel, IshtarSiteProfile, Town  from ishtar_common import forms_common  from ishtar_common.tests import WizardTest, WizardTestFormData as FormData, \ @@ -65,41 +65,22 @@ class ImportOperationTest(TestCase):          self.username, self.password, self.user = create_superuser()          self.ishtar_user = IshtarUser.objects.get(pk=self.user.pk) -    def testMCCImportOperation(self, test=True): -        # MCC opérations -        if self.test_operations is False: -            test = False -        first_ope_nb = models.Operation.objects.count() -        MCC_OPERATION = ImporterType.objects.get(name=u"MCC - Opérations") +    def init_ope_import(self): +        mcc_operation = ImporterType.objects.get(name=u"MCC - Opérations")          mcc_operation_file = open(              settings.ROOT_PATH +              '../archaeological_operations/tests/MCC-operations-example.csv',              'rb')          file_dict = {'imported_file': SimpleUploadedFile(              mcc_operation_file.name, mcc_operation_file.read())} -        post_dict = {'importer_type': MCC_OPERATION.pk, 'skip_lines': 1, +        post_dict = {'importer_type': mcc_operation.pk, 'skip_lines': 1,                       "encoding": 'utf-8'} -        form = forms_common.NewImportForm(data=post_dict, files=file_dict, -                                          instance=None) +        form = forms_common.NewImportForm(data=post_dict, files=file_dict)          form.is_valid() -        if test: -            self.assertTrue(form.is_valid()) -        impt = form.save(self.ishtar_user) -        target_key_nb = TargetKey.objects.count() -        impt.initialize() -        # new key have to be set -        if test: -            self.assertTrue(TargetKey.objects.count() > target_key_nb) +        return mcc_operation, form -        # first try to import -        impt.importation() -        current_ope_nb = models.Operation.objects.count() -        # no new operation imported because of a missing connection for -        # operation_type value -        if test: -            self.assertTrue(current_ope_nb == first_ope_nb) - -        # doing manualy connections +    def init_ope_targetkey(self, imp): +        # doing manually connections          tg = TargetKey.objects.filter(target__target='operation_type'                                        ).order_by('-pk').all()[0]          tg.value = models.OperationType.objects.get( @@ -107,18 +88,44 @@ class ImportOperationTest(TestCase):          tg.is_set = True          tg.save() -        target = TargetKey.objects.get(key='gallo-romain') +        target = TargetKey.objects.get(key='gallo-romain', +                                       associated_import=imp)          gallo = models.Period.objects.get(txt_idx='gallo-roman')          target.value = gallo.pk          target.is_set = True          target.save() -        target = TargetKey.objects.get(key='age-du-fer') +        target = TargetKey.objects.get(key='age-du-fer', +                                       associated_import=imp)          iron = models.Period.objects.get(txt_idx='iron_age')          target.value = iron.pk          target.is_set = True          target.save() +    def test_mcc_import_operation(self, test=True): +        # MCC opérations +        if self.test_operations is False: +            test = False +        first_ope_nb = models.Operation.objects.count() +        importer, form = self.init_ope_import() +        if test: +            self.assertTrue(form.is_valid()) +        impt = form.save(self.ishtar_user) +        target_key_nb = TargetKey.objects.count() +        impt.initialize() +        # new key have to be set +        if test: +            self.assertTrue(TargetKey.objects.count() > target_key_nb) + +        # first try to import +        impt.importation() +        current_ope_nb = models.Operation.objects.count() +        # no new operation imported because of a missing connection for +        # operation_type value +        if test: +            self.assertTrue(current_ope_nb == first_ope_nb) +        self.init_ope_targetkey(imp=impt) +          impt.importation()          if not test:              return @@ -131,8 +138,9 @@ class ImportOperationTest(TestCase):          self.assertTrue(last_ope.code_patriarche == 4200)          self.assertTrue(last_ope.operation_type.txt_idx == 'prog_excavation')          self.assertEqual(last_ope.periods.count(), 2) -        periods = last_ope.periods.all() -        self.assertTrue(iron in periods and gallo in periods) +        periods = [period.txt_idx for period in last_ope.periods.all()] +        self.assertIn('iron_age', periods) +        self.assertIn('gallo-roman', periods)          # a second importation will be not possible: no two same patriarche          # code @@ -141,10 +149,56 @@ class ImportOperationTest(TestCase):          self.assertTrue(last_ope ==                          models.Operation.objects.order_by('-pk').all()[0]) +    def test_model_limitation(self): +        importer, form = self.init_ope_import() +        importer.created_models.clear() +        impt = form.save(self.ishtar_user) +        impt.initialize() +        self.init_ope_targetkey(imp=impt) + +        # no model defined in created_models: normal import +        init_ope_number = models.Operation.objects.count() +        impt.importation() +        current_ope_nb = models.Operation.objects.count() +        self.assertEqual(current_ope_nb, init_ope_number + 1) + +        last_ope = models.Operation.objects.order_by('-pk').all()[0] +        last_ope.delete() + +        importer, form = self.init_ope_import() +        # add an inadequate model to make created_models non empty +        importer.created_models.clear() +        importer.created_models.add(ImporterModel.objects.get( +            klass='ishtar_common.models.Organization' +        )) +        impt = form.save(self.ishtar_user) +        impt.initialize() +        self.init_ope_targetkey(imp=impt) + +        # no imports +        impt.importation() +        current_ope_nb = models.Operation.objects.count() +        self.assertEqual(current_ope_nb, init_ope_number) + +        importer, form = self.init_ope_import() +        # add operation model to allow creation +        importer.created_models.clear() +        importer.created_models.add(ImporterModel.objects.get( +            klass='archaeological_operations.models.Operation' +        )) +        impt = form.save(self.ishtar_user) +        impt.initialize() +        self.init_ope_targetkey(imp=impt) + +        # import of operations +        impt.importation() +        current_ope_nb = models.Operation.objects.count() +        self.assertEqual(current_ope_nb, init_ope_number + 1) +      def testMCCImportParcels(self, test=True):          if self.test_operations is False:              test = False -        self.testMCCImportOperation(test=False) +        self.test_mcc_import_operation(test=False)          old_nb = models.Parcel.objects.count()          MCC_PARCEL = ImporterType.objects.get(name=u"MCC - Parcelles")          mcc_file = open( diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index de4883c69..79259b76d 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -29,6 +29,7 @@ import zipfile  from django.conf import settings  from django.contrib.auth.models import User +from django.core.exceptions import ImproperlyConfigured  from django.core.files import File  from django.db import IntegrityError, DatabaseError, transaction  from django.template.defaultfilters import slugify @@ -613,6 +614,8 @@ class Importer(object):      OBJECT_CLS = None      IMPORTED_LINE_FIELD = None      UNICITY_KEYS = [] +    # if set only models inside this list can be created +    MODEL_CREATION_LIMIT = []      EXTRA_DEFAULTS = {}      DEFAULTS = {}      ERRORS = { @@ -626,10 +629,19 @@ class Importer(object):          'no_data': _(u"No data provided"),          'value_required': _(u"Value is required"),          'not_enough_cols': _(u"At least %d columns must be filled"), -        'regex_not_match': _(u"The regexp doesn't match.") +        'regex_not_match': _(u"The regexp doesn't match."), +        'improperly_configured': _( +            u"Force creation is set for model {} but this model is not in the " +            u"list of model allowed to be created."), +        'does_not_exist_in_db': _(u"{} with values {} doesn't exist in the " +          u"database. Create it first or fix your source file"),      }      def _create_models(self, force=False): +        """ +        Create a db config from a hardcoded import. +        Not useful anymore? +        """          from ishtar_common import models          q = models.ImporterType.objects.filter(slug=self.SLUG)          if not force and (not self.SLUG or q.count()): @@ -1009,11 +1021,32 @@ class Importer(object):                      if k not in formater.through_unicity_keys \                         and k != 'defaults':                          data['defaults'][k] = data.pop(k) +            created = False              if '__force_new' in data: +                if self.MODEL_CREATION_LIMIT and \ +                                through_cls not in self.MODEL_CREATION_LIMIT: +                    raise ImproperlyConfigured( +                        unicode(self.ERRORS[ 'improperly_configured']).format( +                            through_cls))                  created = data.pop('__force_new')                  t_obj = through_cls.objects.create(**data)              else: -                t_obj, created = through_cls.objects.get_or_create(**data) +                if not self.MODEL_CREATION_LIMIT or \ +                        through_cls in self.MODEL_CREATION_LIMIT: +                    t_obj, created = through_cls.objects.get_or_create(**data) +                else: +                    get_data = data.copy() +                    if 'defaults' in get_data: +                        get_data.pop('defaults') +                    try: +                        t_obj = through_cls.objects.get(**get_data) +                    except through_cls.DoesNotExist: +                        values = u", ".join( +                            [u"{}: {}".format(k, get_data[k]) for k in get_data] +                        ) +                        raise ImporterError( +                            unicode(self.ERRORS['does_not_exist_in_db'] +                                    ).format(through_cls, values))              if not created and 'defaults' in data:                  for k in data['defaults']:                      setattr(t_obj, k, data['defaults'][k]) @@ -1247,6 +1280,12 @@ class Importer(object):                              new_created[attribute].append(key)                              has_values = bool([1 for k in v if v[k]])                              if has_values: +                                if self.MODEL_CREATION_LIMIT and \ +                                        model not in self.MODEL_CREATION_LIMIT: +                                    raise ImproperlyConfigured( +                                        unicode( +                                            self.ERRORS['improperly_configured'] +                                        ).format(model))                                  v = model.objects.create(**v)                              else:                                  continue @@ -1255,14 +1294,32 @@ class Importer(object):                              extra_fields = {}                              # "File" type is a temp object and can be different                              # for the same filename - it must be treated -                            # separatly +                            # separately                              for field in model._meta.fields:                                  k = field.name -                                # attr_class est un attribut de FileField +                                # attr_class is a FileField attribute                                  if hasattr(field, 'attr_class') and k in v:                                      extra_fields[k] = v.pop(k) -                            v, created = model.objects.get_or_create( -                                **v) +                            if not self.MODEL_CREATION_LIMIT or \ +                                    model in self.MODEL_CREATION_LIMIT: +                                v, created = model.objects.get_or_create( +                                    **v) +                            else: +                                get_v = v.copy() +                                if 'defaults' in get_v: +                                    get_v.pop('defaults') +                                try: +                                    v = model.objects.get(**get_v) +                                except model.DoesNotExist: +                                    values = u", ".join( +                                        [u"{}: {}".format(k, get_v[k]) +                                         for k in get_v] +                                    ) +                                    raise ImporterError( +                                        unicode( +                                            self.ERRORS[ +                                                'does_not_exist_in_db'] +                                        ).format(model, values))                              changed = False                              for k in extra_fields.keys():                                  if extra_fields[k]: @@ -1336,6 +1393,7 @@ class Importer(object):                      'history_modifier': create_dict.pop('history_modifier')                  }) +            created = False              try:                  try:                      dct = create_dict.copy() @@ -1348,6 +1406,11 @@ class Importer(object):                              return None, created                          new_dct = defaults.copy()                          new_dct.update(dct) +                        if self.MODEL_CREATION_LIMIT and \ +                                cls not in self.MODEL_CREATION_LIMIT: +                            raise ImproperlyConfigured( +                                unicode(self.ERRORS[ 'improperly_configured'] +                                        ).format(cls))                          obj = cls.objects.create(**new_dct)                      else:                          # manage UNICITY_KEYS - only level 1 @@ -1356,9 +1419,21 @@ class Importer(object):                                  if k not in self.UNICITY_KEYS \                                     and k != 'defaults':                                      defaults[k] = dct.pop(k) - -                        dct['defaults'] = defaults.copy() -                        obj, created = cls.objects.get_or_create(**dct) +                        if not self.MODEL_CREATION_LIMIT or \ +                                cls in self.MODEL_CREATION_LIMIT: +                            dct['defaults'] = defaults.copy() +                            obj, created = cls.objects.get_or_create(**dct) +                        else: +                            try: +                                obj = cls.objects.get(**dct) +                                dct['defaults'] = defaults.copy() +                            except cls.DoesNotExist: +                                values = u", ".join( +                                    [u"{}: {}".format(k, dct[k]) for k in dct] +                                ) +                                raise ImporterError( +                                    unicode(self.ERRORS['does_not_exist_in_db'] +                                            ).format(cls, values))                          if not created and not path and self.UNICITY_KEYS:                              changed = False diff --git a/ishtar_common/models.py b/ishtar_common/models.py index 6cf5bff7d..c27f9cc29 100644 --- a/ishtar_common/models.py +++ b/ishtar_common/models.py @@ -35,7 +35,8 @@ import zipfile  from django.conf import settings  from django.core.cache import cache -from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.exceptions import ObjectDoesNotExist, ValidationError, \ +    SuspiciousOperation  from django.core.files import File  from django.core.files.uploadedfile import SimpleUploadedFile  from django.core.validators import validate_slug @@ -1723,9 +1724,16 @@ def get_model_fields(model):  def import_class(full_path_classname): +    """ +    Return the model class from the full path +    TODO: add a white list for more security +    """      mods = full_path_classname.split('.')      if len(mods) == 1:          mods = ['ishtar_common', 'models', mods[0]] +    elif 'models' not in mods: +        raise SuspiciousOperation( +            u"Try to import a non model from a string")      module = import_module('.'.join(mods[:-1]))      return getattr(module, mods[-1]) @@ -1820,9 +1828,13 @@ class ImporterType(models.Model):          UNICITY_KEYS = []          if self.unicity_keys:              UNICITY_KEYS = [un.strip() for un in self.unicity_keys.split(';')] +        MODEL_CREATION_LIMIT = [] +        for modls in self.created_models.all(): +            MODEL_CREATION_LIMIT.append(import_class(modls.klass))          args = {'OBJECT_CLS': OBJECT_CLS, 'DESC': self.description,                  'DEFAULTS': DEFAULTS, 'LINE_FORMAT': LINE_FORMAT, -                'UNICITY_KEYS': UNICITY_KEYS} +                'UNICITY_KEYS': UNICITY_KEYS, +                'MODEL_CREATION_LIMIT': MODEL_CREATION_LIMIT}          name = str(''.join(              x for x in slugify(self.name).replace('-', ' ').title()              if not x.isspace())) diff --git a/ishtar_common/tests.py b/ishtar_common/tests.py index a512dcc07..a3fa62ce7 100644 --- a/ishtar_common/tests.py +++ b/ishtar_common/tests.py @@ -27,6 +27,7 @@ from django.contrib.contenttypes.models import ContentType  from django.core.cache import cache  from django.core.exceptions import ValidationError  from django.core.files.base import File as DjangoFile +from django.core.files.uploadedfile import SimpleUploadedFile  from django.core.management import call_command  from django.core.urlresolvers import reverse  from django.template.defaultfilters import slugify @@ -35,6 +36,7 @@ from django.test.client import Client  from django.test.simple import DjangoTestSuiteRunner  from ishtar_common import models +from ishtar_common import forms_common  from ishtar_common.utils import post_save_point  """ @@ -275,7 +277,7 @@ class AdminGenTypeTest(TestCase):          models.OrganizationType, models.PersonType, models.TitleType,          models.AuthorType, models.SourceType, models.OperationType,          models.SpatialReferenceSystem, models.Format, models.SupportType] -    models_with_data = gen_models  + [models.ImporterModel] +    models_with_data = gen_models + [models.ImporterModel]      models = models_with_data      module_name = 'ishtar_common' | 
