diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
-rw-r--r-- | ishtar_common/data_importer.py | 519 |
1 files changed, 327 insertions, 192 deletions
diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 2ab5ba28f..b88fbff2f 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Copyright (C) 2013-2015 Étienne Loks <etienne.loks_AT_peacefrogsDOTnet> +# Copyright (C) 2013-2017 Étienne Loks <etienne.loks_AT_peacefrogsDOTnet> # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -36,11 +36,28 @@ from django.db.models import Q from django.template.defaultfilters import slugify from django.utils.translation import ugettext_lazy as _ +from ishtar_common.utils import get_all_field_names + + NEW_LINE_BREAK = '#####@@@#####' RE_FILTER_CEDEX = re.compile("(.*) *(?: *CEDEX|cedex|Cedex|Cédex|cédex *\d*)") +def post_importer_action(func): + def wrapper(self, context, value): + return func(self, context, value) + wrapper.importer_trigger = 'post' + return wrapper + + +def pre_importer_action(func): + def wrapper(self, context, value): + return func(self, context, value) + wrapper.importer_trigger = 'pre' + return wrapper + + class ImportFormater(object): def __init__(self, field_name, formater=None, required=True, through=None, through_key=None, through_dict=None, @@ -74,20 +91,20 @@ class ImportFormater(object): self.force_new = force_new self.label = label - def reinit_db_target(self, db_target, nb=0): + def reinit_db_target(self, db_target, nb=0, user=None): if not self.formater: return if type(db_target) in (list, tuple): db_target = db_target[nb] if type(self.formater) not in (list, tuple): self.formater.db_target = db_target - self.formater.init_db_target() + self.formater.init_db_target(user=user) else: for idx, formater in enumerate(self.formater): formater.db_target = db_target - formater.init_db_target() + formater.init_db_target(user=user) - def init_db_target(self): + def init_db_target(self, user=None): pass def __unicode__(self): @@ -100,7 +117,7 @@ class ImportFormater(object): return def init(self, vals, output=None, choose_default=False, - import_instance=None): + import_instance=None, user=None): try: lst = iter(self.formater) except TypeError: @@ -109,7 +126,8 @@ class ImportFormater(object): if formater: formater.check(vals, output, self.comment, choose_default=choose_default, - import_instance=import_instance) + import_instance=import_instance, + user=user) def post_process(self, obj, context, value, owner=None): raise NotImplemented() @@ -135,12 +153,39 @@ class Formater(object): return value def check(self, values, output=None, comment='', choose_default=False, - import_instance=None): + import_instance=None, user=None): return - def init_db_target(self): + def init_db_target(self, user=None): pass + def _base_target_filter(self, user=None): + # set for all users + q_or = ( + Q(associated_import__isnull=True) & + Q(associated_user__isnull=True) & + Q(associated_group__isnull=True) + ) + if hasattr(self, 'import_instance') and self.import_instance: + # set for current import + q_or = q_or | Q(associated_import=self.import_instance) + if self.import_instance.associated_group: + # set for associated group + q_or = q_or | Q( + associated_group=self.import_instance.associated_group) + if user: + # set for current user + q_or = q_or | Q(associated_user=user) + return q_or + + def get_db_target_query(self, user=None): + if not self.db_target: + return + q = self.db_target.keys.filter(is_set=True) + q_or = self._base_target_filter(user) + q = q.filter(q_or) + return q + class ChoiceChecker(object): def report_new(self, comment): @@ -290,13 +335,12 @@ class StrChoiceFormater(Formater, ChoiceChecker): self.equiv_dict[value] = v self.init_db_target() - def init_db_target(self): + def init_db_target(self, user=None): if not self.db_target: return - q = self.db_target.keys.filter(is_set=True) - if self.import_instance: - q = q.filter(Q(associated_import=self.import_instance) | - Q(associated_import__isnull=True)) + + q = self.get_db_target_query(user) + for target_key in q.all(): key = target_key.key if not self.strict: @@ -330,7 +374,10 @@ class StrChoiceFormater(Formater, ChoiceChecker): return msgstr, idx def check(self, values, output=None, comment='', choose_default=False, - import_instance=None): + import_instance=None, user=None): + self.init_db_target(user) + + """ from ishtar_common.models import TargetKey if self.db_target: q = {'target': self.db_target, @@ -343,12 +390,12 @@ class StrChoiceFormater(Formater, ChoiceChecker): if hasattr(value, 'pk'): value = value.pk q['value'] = value - with transaction.commit_on_success(): + with transaction.atomic(): try: t, created = TargetKey.objects.get_or_create(**q) except IntegrityError: pass - + """ if (not output or output == 'silent') and not choose_default: return if self.many_split: @@ -384,7 +431,7 @@ class StrChoiceFormater(Formater, ChoiceChecker): if self.model and v: v = self.model.objects.get(pk=v) self.equiv_dict[value] = v - self.add_key(v, value) + self.add_key(v, value, import_instance) self.new_keys[value] = v elif self.create and res == len(self.choices): self.equiv_dict[value] = self.new(base_value) @@ -395,17 +442,17 @@ class StrChoiceFormater(Formater, ChoiceChecker): self.equiv_dict[value] = None if self.equiv_dict[value] and self.db_target: from ishtar_common.models import TargetKey - q = {'target': self.db_target, 'key': value, - 'associated_import': import_instance, - } + q = {'target': self.db_target, 'key': value} query = TargetKey.objects.filter(**q) + query = query.filter(self._base_target_filter(user)) if query.count(): target = query.all()[0] target.value = self.equiv_dict[value] target.is_set = True target.save() else: - with transaction.commit_on_success(): + q['associated_import'] = import_instance + with transaction.atomic(): q['value'] = self.equiv_dict[value] q['is_set'] = True try: @@ -415,11 +462,13 @@ class StrChoiceFormater(Formater, ChoiceChecker): if output == 'db' and self.db_target: from ishtar_common.models import TargetKey for missing in self.missings: - q = {'target': self.db_target, 'key': missing, - 'associated_import': import_instance} - if TargetKey.objects.filter(**q).count(): + q = {'target': self.db_target, 'key': missing} + query = TargetKey.objects.filter(**q) + query = query.filter(self._base_target_filter(user)) + if query.count(): continue - with transaction.commit_on_success(): + with transaction.atomic(): + q['associated_import'] = import_instance try: TargetKey.objects.create(**q) except IntegrityError: @@ -430,7 +479,7 @@ class StrChoiceFormater(Formater, ChoiceChecker): def new(self, value): return - def add_key(self, obj, value): + def add_key(self, obj, value, importer=None): return def format(self, value): @@ -460,20 +509,20 @@ class TypeFormater(StrChoiceFormater): if self.import_instance: for item in model.objects.all(): self.choices.append((item.pk, unicode(item))) - for key in item.get_keys(importer_id=import_instance.pk): + for key in item.get_keys(importer=import_instance): self.equiv_dict[key] = item def prepare(self, value): return slugify(unicode(value).strip()) - def add_key(self, obj, value): - obj.add_key(slugify(value), force=True) + def add_key(self, obj, value, importer=None): + obj.add_key(slugify(value), force=True, importer=importer) def new(self, value): values = copy.copy(self.defaults) values['label'] = value values['txt_idx'] = slugify(value) - if 'order' in self.model._meta.get_all_field_names(): + if 'order' in get_all_field_names(self.model): order = 1 q = self.model.objects.values('order').order_by('-order') if q.count(): @@ -538,15 +587,16 @@ class StrToBoolean(Formater, ChoiceChecker): self.strict = strict self.db_target = db_target self.missings = set() - self.init_db_target() self.match_table = {} self.new_keys = {} self.import_instance = import_instance + self.init_db_target() - def init_db_target(self): + def init_db_target(self, user=None): if not self.db_target: return - for target_key in self.db_target.keys.filter(is_set=True).all(): + q = self.get_db_target_query(user) + for target_key in q.all(): key = self.prepare(target_key.key) if key in self.dct: continue @@ -560,7 +610,7 @@ class StrToBoolean(Formater, ChoiceChecker): return value def check(self, values, output=None, comment='', choose_default=False, - import_instance=None): + import_instance=None, user=None): if (not output or output == 'silent') and not choose_default: return msgstr = comment + u" - " @@ -812,7 +862,7 @@ class Importer(object): self.output = output if not self.history_modifier: if self.import_instance: - self.history_modifier = self.import_instance.user + self.history_modifier = self.import_instance.user.user_ptr else: # import made by the CLI: get the first admin self.history_modifier = User.objects.filter( @@ -825,16 +875,20 @@ class Importer(object): for related_key in item.RELATED_POST_PROCESS: for related in getattr(item, related_key).all(): related.save() + for func, context, value in self._item_post_processing: + getattr(item, func)(context, value) return item - def initialize(self, table, output='silent', choose_default=False): + def initialize(self, table, output='silent', choose_default=False, + user=None): """ copy vals in columns and initialize formaters * output: - - 'silent': no associations - - 'cli': output by command line interface and stocked in the database - - 'db': output on the database with no interactive association + - silent: no associations + - cli: output by command line interface and stocked in the database + - db: output on the database with no interactive association (further exploitation by web interface) + - user: associated user """ assert output in ('silent', 'cli', 'db') vals = [] @@ -859,14 +913,17 @@ class Importer(object): db_targets.append( self.DB_TARGETS["{}-{}".format( idx + 1, field_name)]) - formater.reinit_db_target(db_targets) + formater.reinit_db_target(db_targets, user=user) formater.init(vals[idx], output, choose_default=choose_default, - import_instance=self.import_instance) + import_instance=self.import_instance, + user=user) - def importation(self, table, initialize=True, choose_default=False): + def importation(self, table, initialize=True, choose_default=False, + user=None): if initialize: - self.initialize(table, self.output, choose_default=choose_default) + self.initialize(table, self.output, + choose_default=choose_default, user=user) self._importation(table) def _associate_db_target_to_formaters(self): @@ -990,6 +1047,7 @@ class Importer(object): return self._throughs = [] # list of (formater, value) self._post_processing = [] # list of (formater, value) + self._item_post_processing = [] data = {} # keep in database the raw line for testing purpose @@ -1026,8 +1084,7 @@ class Importer(object): if self.test: return # manage unicity of items (mainly for updates) - if 'history_modifier' in \ - self.OBJECT_CLS._meta.get_all_field_names(): + if 'history_modifier' in get_all_field_names(self.OBJECT_CLS): data['history_modifier'] = self.history_modifier obj, created = self.get_object(self.OBJECT_CLS, data) @@ -1096,7 +1153,7 @@ class Importer(object): for formater, val in self._post_processing: formater.post_process(obj, data, val, owner=self.history_modifier) - obj = self.post_processing(obj, data) + self.post_processing(obj, data) def _row_processing(self, c_row, idx_col, idx_line, val, data): if idx_col >= len(self.line_format): @@ -1239,162 +1296,241 @@ class Importer(object): concat_str=concat_str[idx]) c_row.append(u" ; ".join([v for v in c_values])) + def _get_field_m2m(self, attribute, data, c_path, new_created, + field_object): + """ + Manage and m2m field from raw data + + :param attribute: attribute name + :param data: current data dictionary + :param c_path: attribute path from the main model point of view + :param new_created: dict of forced newly created items to prevent + multiple creation + :param field_object: django field object for this attribute + :return: None + """ + m2ms = [] + + many_values = data.pop(attribute) + if hasattr(field_object, 'rel'): + model = field_object.rel.to + elif hasattr(field_object, 'to'): + model = field_object.to + elif hasattr(field_object, 'model'): + model = field_object.model + if type(many_values) not in (list, tuple): + many_values = [many_values] + + for val in many_values: + if val.__class__ == model: + # the value is a model instance: it is OK! + m2ms.append((attribute, val)) + continue + if type(val) != dict: + # value is not a dict, we don't know what to do with it... + continue + vals = [] + + # contruct many dict for each values + default_dict = {} + + # # get default values + p = [attribute] + if c_path: + p = list(c_path) + p + p = tuple(p) + if p in self._defaults: + for k in self._defaults[p]: + default_dict[k] = self._defaults[p][k] + + # # init with simple values that will be duplicated + for key in val.keys(): + if type(val[key]) not in (list, tuple): + default_dict[key] = val[key] + vals.append(default_dict.copy()) + + # # manage multiple values + for key in val.keys(): + if type(val[key]) in (list, tuple): + for idx, v in enumerate(val[key]): + if len(vals) <= idx: + vals.append(default_dict.copy()) + vals[idx][key] = v + + # check that m2m are not empty + notempty = False + for dct in vals: + for k in dct: + if dct[k] not in ("", None): + notempty = True + break + if not notempty: + continue + + field_names = get_all_field_names(model) + for v in vals: + if 'history_modifier' in field_names: + if 'defaults' not in v: + v['defaults'] = {} + v['defaults']['history_modifier'] = \ + self.history_modifier + m2m_m2ms = [] + c_c_path = c_path[:] + for k in v.keys(): + if k not in field_names: + continue + self.get_field(model, k, v, m2m_m2ms, c_c_path, + new_created) + if '__force_new' in v: + created = v.pop('__force_new') + key = u";".join([u"{}-{}".format(k, v[k]) + for k in sorted(v.keys())]) + # only one forced creation + if attribute in new_created \ + and key in new_created[attribute]: + continue + if attribute not in new_created: + new_created[attribute] = [] + new_created[attribute].append(key) + has_values = bool([1 for k in v if v[k]]) + if has_values: + if self.MODEL_CREATION_LIMIT and \ + model not in self.MODEL_CREATION_LIMIT: + raise self._get_improperly_conf_error(model) + v = model.objects.create(**v) + else: + continue + else: + v['defaults'] = v.get('defaults', {}) + extra_fields = {} + # "File" type is a temp object and can be different + # for the same filename - it must be treated + # separately + for field in model._meta.fields: + k = field.name + # attr_class is a FileField attribute + if hasattr(field, 'attr_class') and k in v: + extra_fields[k] = v.pop(k) + created = False + if not self.MODEL_CREATION_LIMIT or \ + model in self.MODEL_CREATION_LIMIT: + v, created = model.objects.get_or_create( + **v) + else: + get_v = v.copy() + if 'defaults' in get_v: + get_v.pop('defaults') + try: + v = model.objects.get(**get_v) + except model.DoesNotExist: + raise self._get_does_not_exist_in_db_error( + model, get_v) + changed = False + for k in extra_fields.keys(): + if extra_fields[k]: + changed = True + setattr(v, k, extra_fields[k]) + if changed: + v.save() + for att, objs in m2m_m2ms: + if type(objs) not in (list, tuple): + objs = [objs] + for obj in objs: + getattr(v, att).add(obj) + if self.import_instance \ + and hasattr(v, 'imports') and created: + v.imports.add(self.import_instance) + m2ms.append((attribute, v)) + return m2ms + + def _set_importer_trigger(self, cls, attribute, data): + """ + An importer trigger is used. Stock it for later execution and remove + it from current data dict. + + :param cls: current model + :param attribute: attribute name + :param data: current data dictionary + :return: None + """ + func = getattr(cls, attribute) + if func.importer_trigger == 'pre': + pass # TODO + elif func.importer_trigger == 'post': + self._item_post_processing.append([attribute, data, + data[attribute]]) + else: + logger.warning("Unknow importer_trigger '{}' for '{}'".format( + func.importer_trigger, attribute + )) + data.pop(attribute) + def get_field(self, cls, attribute, data, m2ms, c_path, new_created): + """ + Get field from raw data + + :param cls: current model + :param attribute: attribute name + :param data: current data dictionary + :param m2ms: many to many list of tuple: (m2m key, m2m value) + :param c_path: attribute path from the main model point of view + :param new_created: dict of forced newly created items to prevent + multiple creation + :return: None + """ + if hasattr(cls, attribute) and \ + getattr(getattr(cls, attribute), 'importer_trigger', None): + # importer trigger + self._set_importer_trigger(cls, attribute, data) + return try: - field_object, model, direct, m2m = \ - cls._meta.get_field_by_name(attribute) + field_object = cls._meta.get_field(attribute) except FieldDoesNotExist: raise ImporterError(unicode( _(u"Importer configuration error: field \"{}\" does not exist " u"for {}.")).format(attribute, cls._meta.verbose_name)) - if m2m: - many_values = data.pop(attribute) - if hasattr(field_object, 'rel'): - model = field_object.rel.to - elif hasattr(field_object, 'to'): - model = field_object.to - elif hasattr(field_object, 'model'): - model = field_object.model - if type(many_values) not in (list, tuple): - many_values = [many_values] - for val in many_values: - if val.__class__ == model: - m2ms.append((attribute, val)) - elif val.__class__ != model and type(val) == dict: - vals = [] - - # contruct many dict for each values - default_dict = {} - - # # get default values - p = [attribute] - if c_path: - p = list(c_path) + p - p = tuple(p) - if p in self._defaults: - for k in self._defaults[p]: - default_dict[k] = self._defaults[p][k] - # # init with simple values that will be duplicated - for key in val.keys(): - if type(val[key]) not in (list, tuple): - default_dict[key] = val[key] - vals.append(default_dict.copy()) - # # manage multiple values - for key in val.keys(): - if type(val[key]) in (list, tuple): - for idx, v in enumerate(val[key]): - if len(vals) <= idx: - vals.append(default_dict.copy()) - vals[idx][key] = v - - # check that m2m are not empty - notempty = False - for dct in vals: - for k in dct: - if dct[k] not in ("", None): - notempty = True - break - if not notempty: - continue - - field_names = model._meta.get_all_field_names() - for v in vals: - if 'history_modifier' in field_names: - if 'defaults' not in v: - v['defaults'] = {} - v['defaults']['history_modifier'] = \ - self.history_modifier - m2m_m2ms = [] - c_c_path = c_path[:] - for k in v.keys(): - if k not in field_names: - continue - self.get_field(model, k, v, m2m_m2ms, c_c_path, - new_created) - if '__force_new' in v: - created = v.pop('__force_new') - key = u";".join([u"{}-{}".format(k, v[k]) - for k in sorted(v.keys())]) - # only one forced creation - if attribute in new_created \ - and key in new_created[attribute]: - continue - if attribute not in new_created: - new_created[attribute] = [] - new_created[attribute].append(key) - has_values = bool([1 for k in v if v[k]]) - if has_values: - if self.MODEL_CREATION_LIMIT and \ - model not in self.MODEL_CREATION_LIMIT: - raise self._get_improperly_conf_error(model) - v = model.objects.create(**v) - else: - continue - else: - v['defaults'] = v.get('defaults', {}) - extra_fields = {} - # "File" type is a temp object and can be different - # for the same filename - it must be treated - # separately - for field in model._meta.fields: - k = field.name - # attr_class is a FileField attribute - if hasattr(field, 'attr_class') and k in v: - extra_fields[k] = v.pop(k) - if not self.MODEL_CREATION_LIMIT or \ - model in self.MODEL_CREATION_LIMIT: - v, created = model.objects.get_or_create( - **v) - else: - get_v = v.copy() - if 'defaults' in get_v: - get_v.pop('defaults') - try: - v = model.objects.get(**get_v) - except model.DoesNotExist: - raise self._get_does_not_exist_in_db_error( - model, get_v) - changed = False - for k in extra_fields.keys(): - if extra_fields[k]: - changed = True - setattr(v, k, extra_fields[k]) - if changed: - v.save() - for att, objs in m2m_m2ms: - if type(objs) not in (list, tuple): - objs = [objs] - for obj in objs: - getattr(v, att).add(obj) - if self.import_instance \ - and hasattr(v, 'imports') and created: - v.imports.add(self.import_instance) - m2ms.append((attribute, v)) - elif hasattr(field_object, 'rel') and field_object.rel: - if type(data[attribute]) == dict: - # put history_modifier for every created item - if 'history_modifier' in \ - field_object.rel.to._meta.get_all_field_names(): - data[attribute]['history_modifier'] = \ - self.history_modifier - try: - c_path.append(attribute) - data[attribute], created = self.get_object( - field_object.rel.to, data[attribute].copy(), c_path) - except ImporterError, msg: - self.errors.append((self.idx_line, None, msg)) - data[attribute] = None - elif type(data[attribute]) == list: - data[attribute] = data[attribute][0] + if field_object.many_to_many: + m2ms += self._get_field_m2m(attribute, data, c_path, + new_created, field_object) + return + if not hasattr(field_object, 'rel') or not field_object.rel: + return + if type(data[attribute]) == list: + # extract the first item from list + # be careful if the list has more than one item this is arbitrary + if len(data[attribute]) > 1: + logger.warning( + 'Import {}: {} has many when only one is expected. Get ' + 'the first one but it is not OK!'.format( + self.import_instance, attribute)) + data[attribute] = data[attribute][0] + return + if type(data[attribute]) != dict: + # we treat only dict formated values + return + # put history_modifier for every created item + if 'history_modifier' in get_all_field_names(field_object.rel.to): + data[attribute]['history_modifier'] = \ + self.history_modifier + try: + c_path.append(attribute) + data[attribute], created = self.get_object( + field_object.rel.to, data[attribute].copy(), c_path) + except ImporterError, msg: + self.errors.append((self.idx_line, None, msg)) + data[attribute] = None def get_object(self, cls, data, path=[]): m2ms = [] if type(data) != dict: + # if data is not a dict we don't know what to do return data, False + is_empty = not bool( [k for k in data if k not in ('history_modifier', 'defaults') and data[k]]) if is_empty: + # if no value, no creation return None, False c_path = path[:] @@ -1408,9 +1544,8 @@ class Importer(object): data.pop(attribute) continue if not data[attribute]: - field_object, model, direct, m2m = \ - cls._meta.get_field_by_name(attribute) - if m2m: + field_object = cls._meta.get_field(attribute) + if field_object.many_to_many: data.pop(attribute) continue if attribute != '__force_new': @@ -1436,7 +1571,7 @@ class Importer(object): if type(create_dict[k]) == dict: create_dict.pop(k) # File doesn't like deepcopy - if type(create_dict[k]) == File: + elif type(create_dict[k]) == File: create_dict[k] = copy.copy(data[k]) # default values |