diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 149 | 
1 files changed, 110 insertions, 39 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 27714458b..992025bbb 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -63,14 +63,15 @@ class ImportFormater(object):      def report_error(self, *args):          return -    def init(self, vals, output=None, import_instance=None): +    def init(self, vals, output=None, choose_default=False, import_instance=None):          try:              lst = iter(self.formater)          except TypeError:              lst = [self.formater]          for formater in lst:              if formater: -                formater.check(vals, output, +                formater.check(vals, output, self.comment, +                               choose_default=choose_default,                                 import_instance=import_instance)      def post_process(self, obj, context, value, owner=None): @@ -93,9 +94,20 @@ class Formater(object):      def format(self, value):          return value -    def check(self, values, output=None, import_instance=None): +    def check(self, values, output=None, comment='', choose_default=False, +              import_instance=None):          return +class ChoiceChecker(object): +    def report_new(self, comment): +        if not self.new_keys: +            return +        msg = u"For \"%s\" these new associations have been made:\n" % comment +        sys.stderr.write(msg.encode('utf-8')) +        for k in self.new_keys: +            msg = u'"%s";"%s"\n' % (k, self.new_keys[k]) +            sys.stderr.write(msg.encode('utf-8')) +  class UnicodeFormater(Formater):      def __init__(self, max_length, clean=False, re_filter=None, notnull=False,                   db_target=None): @@ -186,7 +198,7 @@ class IntegerFormater(Formater):              raise ValueError(_(u"\"%(value)s\" is not an integer") % {                                                                   'value':value}) -class StrChoiceFormater(Formater): +class StrChoiceFormater(Formater, ChoiceChecker):      def __init__(self, choices, strict=False, equiv_dict={}, model=None,                   cli=False, many_split='', db_target=None):          self.choices = list(choices) @@ -197,6 +209,7 @@ class StrChoiceFormater(Formater):          self.db_target = db_target          self.create = False          self.missings = set() +        self.new_keys = {}          self.many_split = many_split          for key, value in self.choices:              value = unicode(value) @@ -222,8 +235,9 @@ class StrChoiceFormater(Formater):      def prepare(self, value):          return unicode(value).strip() -    def _get_choices(self): -        msgstr = unicode(_(u"Choice for \"%s\" is not available. "\ +    def _get_choices(self, comment=''): +        msgstr = comment + u" - " +        msgstr += unicode(_(u"Choice for \"%s\" is not available. "\                             u"Which one is relevant?\n"))          idx = -1          for idx, choice in enumerate(self.choices): @@ -236,8 +250,9 @@ class StrChoiceFormater(Formater):          msgstr += unicode(_(u"%d. None of the above - skip")) % idx + u"\n"          return msgstr, idx -    def check(self, values, output=None, import_instance=None): -        if not output or output == 'silent': +    def check(self, values, output=None, comment='', choose_default=False, +              import_instance=None): +        if (not output or output == 'silent') and not choose_default:              return          if self.many_split:              new_values = [] @@ -250,14 +265,18 @@ class StrChoiceFormater(Formater):              value = self.prepare(value)              if value in self.equiv_dict:                  continue -            if output != 'cli': +            if output != 'cli' and not choose_default:                  self.missings.add(value)                  continue -            msgstr, idx = self._get_choices() +            msgstr, idx = self._get_choices(comment)              res = None +            if choose_default: +                res = 1              while res not in range(1, idx+1): -                sys.stdout.write(msgstr % value) -                res = raw_input(">>> ") +                msg = msgstr % value +                sys.stdout.write(msg.encode('utf-8')) +                sys.stdout.write("\n>>> ") +                res = raw_input()                  try:                      res = int(res)                  except ValueError: @@ -269,10 +288,12 @@ class StrChoiceFormater(Formater):                      v = self.model.objects.get(pk=v)                  self.equiv_dict[value] = v                  self.add_key(v, value) +                self.new_keys[value] = v              elif self.create and res == len(self.choices):                  self.equiv_dict[value] = self.new(base_value)                  self.choices.append((self.equiv_dict[value].pk,                                       unicode(self.equiv_dict[value]))) +                self.new_keys[value] = unicode(self.equiv_dict[value])              else:                  self.equiv_dict[value] = None          if output == 'db' and self.db_target: @@ -287,6 +308,8 @@ class StrChoiceFormater(Formater):                          TargetKey.objects.create(**q)                      except IntegrityError:                          pass +        if output == 'cli': +            self.report_new(comment)      def new(self, value):          return @@ -312,6 +335,7 @@ class TypeFormater(StrChoiceFormater):          self.db_target = db_target          self.missings = set()          self.equiv_dict, self.choices = {}, [] +        self.new_keys = {}          for item in model.objects.all():              self.choices.append((item.pk, unicode(item)))              for key in item.get_keys(): @@ -336,21 +360,25 @@ class TypeFormater(StrChoiceFormater):          return self.model.objects.create(**values)  class DateFormater(Formater): -    def __init__(self, date_format="%d/%m/%Y", db_target=None): -        self.date_format = date_format +    def __init__(self, date_formats=["%d/%m/%Y"], db_target=None): +        self.date_formats = date_formats +        if type(date_formats) not in (list, tuple): +            self.date_formats = [self.date_formats]          self.db_target = db_target      def format(self, value):          value = value.strip()          if not value:              return -        try: -            return datetime.datetime.strptime(value, self.date_format).date() -        except: -            raise ValueError(_(u"\"%(value)s\" is not a valid date") % { +        for date_format in self.date_formats: +            try: +                return datetime.datetime.strptime(value, date_format).date() +            except: +                continue +        raise ValueError(_(u"\"%(value)s\" is not a valid date") % {                                                             'value':value}) -class StrToBoolean(Formater): +class StrToBoolean(Formater, ChoiceChecker):      def __init__(self, choices={}, cli=False, strict=False, db_target=None):          self.dct = copy.copy(choices)          self.cli = cli @@ -371,6 +399,7 @@ class StrToBoolean(Formater):                  else:                      v = None                  self.dct[value] = v +        self.new_keys = {}      def prepare(self, value):          value = unicode(value).strip() @@ -378,10 +407,12 @@ class StrToBoolean(Formater):              value = slugify(value)          return value -    def check(self, values, output=None, import_instance=None): -        if not output or output == 'silent': +    def check(self, values, output=None, comment='', choose_default=False, +              import_instance=None): +        if (not output or output == 'silent') and not choose_default:              return -        msgstr = unicode(_(u"Choice for \"%s\" is not available. "\ +        msgstr = comment + u" - " +        msgstr += unicode(_(u"Choice for \"%s\" is not available. "\                             u"Which one is relevant?\n"))          msgstr += u"1. True\n"          msgstr += u"2. False\n" @@ -390,13 +421,17 @@ class StrToBoolean(Formater):              value = self.prepare(value)              if value in self.dct:                  continue -            if not self.cli: +            if output != 'cli' and not choose_default:                  self.missings.add(value)                  continue              res = None +            if choose_default: +                res = 1              while res not in range(1, 4): -                sys.stdout.write(msgstr % value) -                res = raw_input(">>> ") +                msg = msgstr % value +                sys.stdout.write(msg.encode('utf-8')) +                sys.stdout.write("\n>>> ") +                res = raw_input()                  try:                      res = int(res)                  except ValueError: @@ -407,15 +442,19 @@ class StrToBoolean(Formater):                  self.dct[value] = False              else:                  self.dct[value] = None +            self.new_keys[value] = unicode(self.dct[value])          if output == 'db' and self.db_target:              from ishtar_common.models import TargetKey              for missing in self.missings:                  try:                      q = {'target':self.db_target, 'key':missing,                           'associated_import':import_instance} -                    TargetKey.objects.create(**q) +                    if not TargetKey.objects.filter(**q).count(): +                        TargetKey.objects.create(**q)                  except IntegrityError:                      pass +        if output == 'cli': +            self.report_new(comment)      def format(self, value):          value = self.prepare(value) @@ -424,13 +463,22 @@ class StrToBoolean(Formater):  logger = logging.getLogger(__name__) +def get_object_from_path(obj, path): +    for k in path.split('__')[:-1]: +        if not hasattr(obj, k): +            return +        obj = getattr(obj, k) +    return obj +  class Importer(object):      DESC = ""      LINE_FORMAT = []      OBJECT_CLS = None      IMPORTED_LINE_FIELD = None      UNICITY_KEYS = [] +    EXTRA_DEFAULTS = {}      DEFAULTS = {} +    STR_CUT = {}      ERRORS = {          'header_check':_(u"The given file is not correct. Check the file "                    u"format. If you use a CSV file: check that column separator " @@ -464,6 +512,12 @@ class Importer(object):          self.line_format = copy.copy(self.LINE_FORMAT)          self.import_instance = import_instance          self._defaults = self.DEFAULTS.copy() +        # EXTRA_DEFAULTS are for multiple inheritance +        if self.EXTRA_DEFAULTS: +            for k in self.EXTRA_DEFAULTS: +                if k not in self._defaults: +                    self._defaults[k] = {} +                self._defaults[k].update(self.EXTRA_DEFAULTS[k])          self.history_modifier = history_modifier          self.output = output          if not self.history_modifier: @@ -474,7 +528,10 @@ class Importer(object):                  self.history_modifier = User.objects.filter(                                  is_superuser=True).order_by('pk')[0] -    def initialize(self, table, output='silent'): +    def post_processing(self, item, data): +        return item + +    def initialize(self, table, output='silent', choose_default=False):          """          copy vals in columns and initialize formaters          * output: @@ -496,12 +553,12 @@ class Importer(object):                  vals[idx_col].append(val)          for idx, formater in enumerate(self.line_format):              if formater and idx < len(vals): -                formater.init(vals[idx], output, +                formater.init(vals[idx], output, choose_default=choose_default,                                import_instance=self.import_instance) -    def importation(self, table, initialize=True): +    def importation(self, table, initialize=True, choose_default=False):          if initialize: -            self.initialize(table, self.output) +            self.initialize(table, self.output, choose_default=choose_default)          self._importation(table)      @classmethod @@ -566,10 +623,10 @@ class Importer(object):                      time_by_item = ellapsed/idx_line                      if time_by_item:                          left = ((total - idx_line)*time_by_item).seconds -                txt = "\r* %d/%d" % (idx_line+1, total) +                txt = u"\r* %d/%d" % (idx_line+1, total)                  if left: -                    txt += " (%d seconds left)" % left -                sys.stdout.write(txt) +                    txt += u" (%d seconds left)" % left +                sys.stdout.write(txt.encode('utf-8'))                  sys.stdout.flush()              try:                  self._line_processing(idx_line, line) @@ -629,8 +686,8 @@ class Importer(object):          if 'history_modifier' in \                         self.OBJECT_CLS._meta.get_all_field_names():              data['history_modifier'] = self.history_modifier -        obj, created = self.get_object(self.OBJECT_CLS, data) +        obj, created = self.get_object(self.OBJECT_CLS, data)          if self.import_instance and hasattr(obj, 'imports') \             and created:              obj.imports.add(self.import_instance) @@ -677,6 +734,8 @@ class Importer(object):          for formater, val in self._post_processing:              formater.post_process(obj, data, val, owner=self.history_modifier) +        obj = self.post_processing(obj, data) +      def _row_processing(self, c_row, idx_col, idx_line, val, data):          if idx_col >= len(self.line_format):              return @@ -851,12 +910,19 @@ class Importer(object):                                                      self.history_modifier                      data[attribute], created = self.get_object(                                     field_object.rel.to, data[attribute], c_path) +            # default values              path = tuple(path)              if path in self._defaults:                  for k in self._defaults[path]:                      if k not in data or not data[k]:                          data[k] = self._defaults[path][k] +            # pre treatment +            if path in self.STR_CUT: +                for k in self.STR_CUT[path]: +                    if k in data and data[k]: +                        data[k] = unicode(data[k])[:self.STR_CUT[path][k]] +              # filter default values              create_dict = copy.deepcopy(data)              for k in create_dict.keys(): @@ -876,11 +942,7 @@ class Importer(object):                          obj.imports.add(self.import_instance)                  except IntegrityError as e:                      raise IntegrityError(e.message) -                except: -                    q = cls.objects.filter(**create_dict) -                    if not q.count(): -                        raise ImporterError("Erreur d'import %s, contexte : %s"\ -                                                % (unicode(cls), unicode(data))) +                except cls.MultipleObjectsReturned:                      created = False                      obj = cls.objects.filter(**create_dict).all()[0]                  for attr, value in m2ms: @@ -890,6 +952,15 @@ class Importer(object):                      for v in values:                          getattr(obj, attr).add(v)              except IntegrityError as e: +                message = e.message +                try: +                    message = unicode(e.message.decode('utf-8')) +                except (UnicodeDecodeError, UnicodeDecodeError): +                    message = '' +                try: +                    data = unicode(data) +                except UnicodeDecodeError: +                    data = ''                  raise ImporterError(                      "Erreur d'import %s, contexte : %s, erreur : %s" \                      % (unicode(cls), unicode(data), e.message.decode('utf-8'))) | 
