diff options
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 143 | 
1 files changed, 105 insertions, 38 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 333a81877..1e59b574f 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -63,14 +63,15 @@ class ImportFormater(object):      def report_error(self, *args):          return -    def init(self, vals, output=None): +    def init(self, vals, output=None, choose_default=False):          try:              lst = iter(self.formater)          except TypeError:              lst = [self.formater]          for formater in lst:              if formater: -                formater.check(vals, output) +                formater.check(vals, output, self.comment, +                               choose_default=choose_default)      def post_process(self, obj, context, value, owner=None):          raise NotImplemented() @@ -92,9 +93,19 @@ class Formater(object):      def format(self, value):          return value -    def check(self, values, output=None): +    def check(self, values, output=None, comment='', choose_default=False):          return +class ChoiceChecker(object): +    def report_new(self, comment): +        if not self.new_keys: +            return +        msg = u"For \"%s\" these new associations have been made:\n" % comment +        sys.stderr.write(msg.encode('utf-8')) +        for k in self.new_keys: +            msg = u'"%s";"%s"\n' % (k, self.new_keys[k]) +            sys.stderr.write(msg.encode('utf-8')) +  class UnicodeFormater(Formater):      def __init__(self, max_length, clean=False, re_filter=None, notnull=False,                   db_target=None): @@ -183,7 +194,7 @@ class IntegerFormater(Formater):              raise ValueError(_(u"\"%(value)s\" is not an integer") % {                                                                   'value':value}) -class StrChoiceFormater(Formater): +class StrChoiceFormater(Formater, ChoiceChecker):      def __init__(self, choices, strict=False, equiv_dict={}, model=None,                   cli=False, many_split='', db_target=None):          self.choices = list(choices) @@ -194,6 +205,7 @@ class StrChoiceFormater(Formater):          self.db_target = db_target          self.create = False          self.missings = set() +        self.new_keys = {}          self.many_split = many_split          for key, value in self.choices:              value = unicode(value) @@ -219,8 +231,9 @@ class StrChoiceFormater(Formater):      def prepare(self, value):          return unicode(value).strip() -    def _get_choices(self): -        msgstr = unicode(_(u"Choice for \"%s\" is not available. "\ +    def _get_choices(self, comment=''): +        msgstr = comment + u" - " +        msgstr += unicode(_(u"Choice for \"%s\" is not available. "\                             u"Which one is relevant?\n"))          idx = -1          for idx, choice in enumerate(self.choices): @@ -233,8 +246,8 @@ class StrChoiceFormater(Formater):          msgstr += unicode(_(u"%d. None of the above - skip")) % idx + u"\n"          return msgstr, idx -    def check(self, values, output=None): -        if not output or output == 'silent': +    def check(self, values, output=None, comment='', choose_default=False): +        if (not output or output == 'silent') and not choose_default:              return          if self.many_split:              new_values = [] @@ -247,14 +260,18 @@ class StrChoiceFormater(Formater):              value = self.prepare(value)              if value in self.equiv_dict:                  continue -            if output != 'cli': +            if output != 'cli' and not choose_default:                  self.missings.add(value)                  continue -            msgstr, idx = self._get_choices() +            msgstr, idx = self._get_choices(comment)              res = None +            if choose_default: +                res = 1              while res not in range(1, idx+1): -                sys.stdout.write(msgstr % value) -                res = raw_input(">>> ") +                msg = msgstr % value +                sys.stdout.write(msg.encode('utf-8')) +                sys.stdout.write("\n>>> ") +                res = raw_input()                  try:                      res = int(res)                  except ValueError: @@ -266,10 +283,12 @@ class StrChoiceFormater(Formater):                      v = self.model.objects.get(pk=v)                  self.equiv_dict[value] = v                  self.add_key(v, value) +                self.new_keys[value] = v              elif self.create and res == len(self.choices):                  self.equiv_dict[value] = self.new(base_value)                  self.choices.append((self.equiv_dict[value].pk,                                       unicode(self.equiv_dict[value]))) +                self.new_keys[value] = unicode(self.equiv_dict[value])              else:                  self.equiv_dict[value] = None          if output == 'db' and self.db_target: @@ -283,6 +302,8 @@ class StrChoiceFormater(Formater):                          TargetKey.objects.create(**q)                      except IntegrityError:                          pass +        if output == 'cli': +            self.report_new(comment)      def new(self, value):          return @@ -308,6 +329,7 @@ class TypeFormater(StrChoiceFormater):          self.db_target = db_target          self.missings = set()          self.equiv_dict, self.choices = {}, [] +        self.new_keys = {}          for item in model.objects.all():              self.choices.append((item.pk, unicode(item)))              for key in item.get_keys(): @@ -332,21 +354,25 @@ class TypeFormater(StrChoiceFormater):          return self.model.objects.create(**values)  class DateFormater(Formater): -    def __init__(self, date_format="%d/%m/%Y", db_target=None): -        self.date_format = date_format +    def __init__(self, date_formats=["%d/%m/%Y"], db_target=None): +        self.date_formats = date_formats +        if type(date_formats) not in (list, tuple): +            self.date_formats = [self.date_formats]          self.db_target = db_target      def format(self, value):          value = value.strip()          if not value:              return -        try: -            return datetime.datetime.strptime(value, self.date_format).date() -        except: -            raise ValueError(_(u"\"%(value)s\" is not a valid date") % { +        for date_format in self.date_formats: +            try: +                return datetime.datetime.strptime(value, date_format).date() +            except: +                continue +        raise ValueError(_(u"\"%(value)s\" is not a valid date") % {                                                             'value':value}) -class StrToBoolean(Formater): +class StrToBoolean(Formater, ChoiceChecker):      def __init__(self, choices={}, cli=False, strict=False, db_target=None):          self.dct = copy.copy(choices)          self.cli = cli @@ -367,6 +393,7 @@ class StrToBoolean(Formater):                  else:                      v = None                  self.dct[value] = v +        self.new_keys = {}      def prepare(self, value):          value = unicode(value).strip() @@ -374,10 +401,11 @@ class StrToBoolean(Formater):              value = slugify(value)          return value -    def check(self, values, output=None): -        if not output or output == 'silent': +    def check(self, values, output=None, comment='', choose_default=False): +        if (not output or output == 'silent') and not choose_default:              return -        msgstr = unicode(_(u"Choice for \"%s\" is not available. "\ +        msgstr = comment + u" - " +        msgstr += unicode(_(u"Choice for \"%s\" is not available. "\                             u"Which one is relevant?\n"))          msgstr += u"1. True\n"          msgstr += u"2. False\n" @@ -386,13 +414,17 @@ class StrToBoolean(Formater):              value = self.prepare(value)              if value in self.dct:                  continue -            if not self.cli: +            if output != 'cli' and not choose_default:                  self.missings.add(value)                  continue              res = None +            if choose_default: +                res = 1              while res not in range(1, 4): -                sys.stdout.write(msgstr % value) -                res = raw_input(">>> ") +                msg = msgstr % value +                sys.stdout.write(msg.encode('utf-8')) +                sys.stdout.write("\n>>> ") +                res = raw_input()                  try:                      res = int(res)                  except ValueError: @@ -411,6 +443,9 @@ class StrToBoolean(Formater):                      TargetKey.objects.create(**q)                  except IntegrityError:                      pass +            self.new_keys[value] = unicode(self.dct[value]) +        if output == 'cli': +            self.report_new(comment)      def format(self, value):          value = self.prepare(value) @@ -419,13 +454,22 @@ class StrToBoolean(Formater):  logger = logging.getLogger(__name__) +def get_object_from_path(obj, path): +    for k in path.split('__')[:-1]: +        if not hasattr(obj, k): +            return +        obj = getattr(obj, k) +    return obj +  class Importer(object):      DESC = ""      LINE_FORMAT = []      OBJECT_CLS = None      IMPORTED_LINE_FIELD = None      UNICITY_KEYS = [] +    EXTRA_DEFAULTS = {}      DEFAULTS = {} +    STR_CUT = {}      ERRORS = {          'header_check':_(u"The given file is not correct. Check the file "                    u"format. If you use a CSV file: check that column separator " @@ -459,6 +503,12 @@ class Importer(object):          self.line_format = copy.copy(self.LINE_FORMAT)          self.import_instance = import_instance          self._defaults = self.DEFAULTS.copy() +        # EXTRA_DEFAULTS are for multiple inheritance +        if self.EXTRA_DEFAULTS: +            for k in self.EXTRA_DEFAULTS: +                if k not in self._defaults: +                    self._defaults[k] = {} +                self._defaults[k].update(self.EXTRA_DEFAULTS[k])          self.history_modifier = history_modifier          self.output = output          if not self.history_modifier: @@ -469,7 +519,10 @@ class Importer(object):                  self.history_modifier = User.objects.filter(                                  is_superuser=True).order_by('pk')[0] -    def initialize(self, table, output='silent'): +    def post_processing(self, item, data): +        return item + +    def initialize(self, table, output='silent', choose_default=False):          """          copy vals in columns and initialize formaters          * output: @@ -491,11 +544,11 @@ class Importer(object):                  vals[idx_col].append(val)          for idx, formater in enumerate(self.line_format):              if formater and idx < len(vals): -                formater.init(vals[idx], output) +                formater.init(vals[idx], output, choose_default=choose_default) -    def importation(self, table, initialize=True): +    def importation(self, table, initialize=True, choose_default=False):          if initialize: -            self.initialize(table, self.output) +            self.initialize(table, self.output, choose_default=choose_default)          self._importation(table)      @classmethod @@ -560,10 +613,10 @@ class Importer(object):                      time_by_item = ellapsed/idx_line                      if time_by_item:                          left = ((total - idx_line)*time_by_item).seconds -                txt = "\r* %d/%d" % (idx_line+1, total) +                txt = u"\r* %d/%d" % (idx_line+1, total)                  if left: -                    txt += " (%d seconds left)" % left -                sys.stdout.write(txt) +                    txt += u" (%d seconds left)" % left +                sys.stdout.write(txt.encode('utf-8'))                  sys.stdout.flush()              try:                  self._line_processing(idx_line, line) @@ -623,8 +676,8 @@ class Importer(object):          if 'history_modifier' in \                         self.OBJECT_CLS._meta.get_all_field_names():              data['history_modifier'] = self.history_modifier -        obj, created = self.get_object(self.OBJECT_CLS, data) +        obj, created = self.get_object(self.OBJECT_CLS, data)          if self.import_instance and hasattr(obj, 'imports'):              obj.imports.add(self.import_instance) @@ -669,6 +722,8 @@ class Importer(object):          for formater, val in self._post_processing:              formater.post_process(obj, data, val, owner=self.history_modifier) +        obj = self.post_processing(obj, data) +      def _row_processing(self, c_row, idx_col, idx_line, val, data):          if idx_col >= len(self.line_format):              return @@ -836,12 +891,19 @@ class Importer(object):                                                      self.history_modifier                      data[attribute], created = self.get_object(                                     field_object.rel.to, data[attribute], c_path) +            # default values              path = tuple(path)              if path in self._defaults:                  for k in self._defaults[path]:                      if k not in data or not data[k]:                          data[k] = self._defaults[path][k] +            # pre treatment +            if path in self.STR_CUT: +                for k in self.STR_CUT[path]: +                    if k in data and data[k]: +                        data[k] = unicode(data[k])[:self.STR_CUT[path][k]] +              # filter default values              create_dict = copy.deepcopy(data)              for k in create_dict.keys(): @@ -860,11 +922,7 @@ class Importer(object):                          obj.imports.add(self.import_instance)                  except IntegrityError as e:                      raise IntegrityError(e.message) -                except: -                    q = cls.objects.filter(**create_dict) -                    if not q.count(): -                        raise ImporterError("Erreur d'import %s, contexte : %s"\ -                                                % (unicode(cls), unicode(data))) +                except cls.MultipleObjectsReturned:                      created = False                      obj = cls.objects.filter(**create_dict).all()[0]                  for attr, value in m2ms: @@ -874,6 +932,15 @@ class Importer(object):                      for v in values:                          getattr(obj, attr).add(v)              except IntegrityError as e: +                message = e.message +                try: +                    message = unicode(e.message.decode('utf-8')) +                except (UnicodeDecodeError, UnicodeDecodeError): +                    message = '' +                try: +                    data = unicode(data) +                except UnicodeDecodeError: +                    data = ''                  raise ImporterError(                      "Erreur d'import %s, contexte : %s, erreur : %s" \                      % (unicode(cls), unicode(data), e.message.decode('utf-8'))) | 
