diff options
| author | Étienne Loks <etienne.loks@proxience.com> | 2014-11-24 14:50:22 +0100 | 
|---|---|---|
| committer | Étienne Loks <etienne.loks@proxience.com> | 2014-11-24 14:50:22 +0100 | 
| commit | 1094b07f381b658f8325c2723afa2e26b8909ebb (patch) | |
| tree | 660816bc439f5387855dddacf7f6cf107c40de0a /ishtar_common/data_importer.py | |
| parent | 85732c3372a0667e5058929a124ef023850a8ca2 (diff) | |
| download | Ishtar-1094b07f381b658f8325c2723afa2e26b8909ebb.tar.bz2 Ishtar-1094b07f381b658f8325c2723afa2e26b8909ebb.zip | |
Work on SRA importation
Diffstat (limited to 'ishtar_common/data_importer.py')
| -rw-r--r-- | ishtar_common/data_importer.py | 541 | 
1 file changed, 379 insertions, 162 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py index 86285e33e..1d768c6b0 100644 --- a/ishtar_common/data_importer.py +++ b/ishtar_common/data_importer.py @@ -17,139 +17,24 @@  # See the file COPYING for details. -""" -# Usage exemple (extracted from simulabio application) - -class HarvestPlotImporter(Importer): -    LINE_FORMAT = [ -               ImportFormater('name', Importer.get_unicode_formater(100)), -               ImportFormater('plot_group_number', -                              Importer.get_unicode_formater(3), required=False), -               ImportFormater('geographical_area', unicode, required=False), -               ImportFormater('soil_type', Importer.choices_check(SOIL_TYPE), -                                                            required=False), -               ImportFormater('cow_access', Importer.boolean_formater), -               ImportFormater('area', Importer.float_formater), -               ImportFormater('remark', unicode, required=False), -               ImportFormater('diagnostic', Importer.boolean_formater), -               ImportFormater('project', Importer.boolean_formater), -               ImportFormater('harvest_n2', 'harvest_formater', required=False), -               ImportFormater('harvest_n1', 'harvest_formater', required=False), -               ImportFormater('harvest', 'harvest_formater'), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':1}, -                              through_unicity_keys=['plot', 'year'], -                              required=False), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':2}, -                        
      through_unicity_keys=['plot', 'year'], -                              required=False), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':3}, -                              through_unicity_keys=['plot', 'year'], -                              required=False), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':4}, -                              through_unicity_keys=['plot', 'year'], -                              required=False), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':5}, -                              through_unicity_keys=['plot', 'year'], -                              required=False), -               ImportFormater('harvest_setting', 'harvest_formater', -                              through=HarvestTransition, -                              through_key='plot', -                              through_dict={'year':6}, -                              through_unicity_keys=['plot', 'year'], -                              required=False), -               ] -    OBJECT_CLS = HarvestPlots -    UNICITY_KEYS = [] - -    def __init__(self, study, skip_first_line=None): -        # get the reference header -        dct = {'separator':settings.CSV_DELIMITER} -        dct['data'] = Harvest.objects.filter(available=True).all() -        reference_file = render_to_string('simulabio/files/parcelles_ref.csv', -                                          dct) -        reference_header = unicode_csv_reader( -                                
[reference_file.split('\n')[0]]).next() -        super(HarvestPlotImporter, self).__init__( -                                           skip_first_line=skip_first_line, -                                           reference_header=reference_header) -        self.study = study -        self.default_vals = {'study':self.study} - -    def harvest_formater(self, value): -        value = value.strip() -        if not value: -            return -        try: -            harvest = Harvest.objects.get(name__iexact=value) -        except ObjectDoesNotExist: -            raise ValueError(_(u"\"%(value)s\" not in %(values)s") % { -                'value':value, -                'values':u", ".join([val.name -                        for val in Harvest.objects.filter(available=True)]) -                }) -        hs, created = HarvestSettings.objects.get_or_create(study=self.study, -                                                            harvest=harvest) -        if created: -            self.message = _(u"\"%(harvest)s\" has been added in your settings " -                             u"don't forget to fill yields for this harvest.") \ -                             % {'harvest':harvest.name} -        return hs - -class HarvestPlotsImportForm(forms.Form): -    csv_file = forms.FileField(label=_(u"Plot list file (CSV)")) - -    def save(self, study): -        csv_file = self.cleaned_data['csv_file'] -        importer = models.HarvestPlotImporter(study, skip_first_line=True) -        # some softwares (at least Gnumeric) convert CSV file to utf-8 no matter -        # what the CSV source encoding is -        encodings = [settings.ENCODING, 'utf-8'] -        for encoding in encodings: -            try: -                importer.importation(unicode_csv_reader( -                                [line.decode(encoding) -                                 for line in csv_file.readlines()])) -            except ImporterError, e: -                if e.type == ImporterError.HEADER and 
encoding != encodings[-1]: -                    csv_file.seek(0) -                    continue -                return 0, [[0, 0, e.msg]], [] -            except UnicodeDecodeError, e: -                return 0, [[0, 0, Importer.ERRORS['header_check']]], [] -            break -        return importer.number_imported, importer.errors, importer.messages -""" -  import copy, csv, datetime, logging, sys  from tempfile import NamedTemporaryFile  from django.contrib.auth.models import User -from django.db import DatabaseError +from django.db import DatabaseError, IntegrityError  from django.template.defaultfilters import slugify  from django.utils.translation import ugettext_lazy as _  from ishtar_common.unicode_csv import UnicodeWriter +NEW_LINE_BREAK = '#####@@@#####' +  class ImportFormater(object):      def __init__(self, field_name, formater=None, required=True, through=None,                  through_key=None, through_dict=None, through_unicity_keys=None,                  duplicate_field=None, regexp=None, regexp_formater_args=[], -                reverse_for_test=None, comment=""): +                reverse_for_test=None, force_value=None, post_processing=False, +                concat=False, comment=""):          self.field_name = field_name          self.formater = formater          self.required = required @@ -161,6 +46,12 @@ class ImportFormater(object):          self.regexp = regexp          self.regexp_formater_args = regexp_formater_args          self.reverse_for_test = reverse_for_test +        # write this value even if a value exists +        self.force_value = force_value +        # post process after import +        self.post_processing = post_processing +        # concatenate with existing value +        self.concat = concat          self.comment = comment      def __unicode__(self): @@ -178,7 +69,11 @@ class ImportFormater(object):          except TypeError:              lst = [self.formater]          for formater in lst: -            
formater.check(vals) +            if formater: +                formater.check(vals) + +    def post_process(self, obj, context, value, owner=None): +        raise NotImplemented()  class ImporterError(Exception):      STANDARD = 'S' @@ -186,6 +81,7 @@ class ImporterError(Exception):      def __init__(self, message, type='S'):          self.msg = message          self.type = type +      def __str__(self):          return self.msg @@ -214,6 +110,7 @@ class UnicodeFormater(Formater):                      value = value[1:]                  if value.endswith(","):                      value = value[:-1] +                value = value.replace(", , ", ", ")          except UnicodeDecodeError:              return          if len(value) > self.max_length: @@ -243,19 +140,59 @@ class FloatFormater(Formater):              raise ValueError(_(u"\"%(value)s\" is not a float") % {                                                                   'value':value}) +class YearFormater(Formater): +    def format(self, value): +        value = value.strip() +        if not value: +            return +        try: +            value = int(value) +            assert value > 0 and value < (datetime.date.today().year + 30) +        except (ValueError, AssertionError): +            raise ValueError(_(u"\"%(value)s\" is not a valid date") % { +                                                                 'value':value}) + +class YearNoFuturFormater(Formater): +    def format(self, value): +        value = value.strip() +        if not value: +            return +        try: +            value = int(value) +            assert value > 0 and value < (datetime.date.today().year) +        except (ValueError, AssertionError): +            raise ValueError(_(u"\"%(value)s\" is not a valid date") % { +                                                                 'value':value}) + +class IntegerFormater(Formater): +    def format(self, value): +        value = value.strip() +        if not 
value: +            return +        try: +            return int(value) +        except ValueError: +            raise ValueError(_(u"\"%(value)s\" is not an integer") % { +                                                                 'value':value}) +  class StrChoiceFormater(Formater): -    def __init__(self, choices, strict=False, equiv_dict={}, cli=False): -        self.choices = choices +    def __init__(self, choices, strict=False, equiv_dict={}, model=None, +                 cli=False): +        self.choices = list(choices)          self.strict = strict          self.equiv_dict = copy.deepcopy(equiv_dict)          self.cli = cli +        self.model = model          self.missings = set()          for key, value in self.choices:              value = unicode(value)              if not self.strict:                  value = slugify(value)              if value not in self.equiv_dict: -                self.equiv_dict[value] = key +                v = key +                if model and v: +                    v = model.objects.get(pk=v) +                self.equiv_dict[value] = v      def prepare(self, value):          return unicode(value).strip() @@ -283,7 +220,10 @@ class StrChoiceFormater(Formater):                      pass              res -= 1              if res < len(self.choices): -                self.equiv_dict[value] = self.choices[res] +                v = self.choices[res][0] +                if self.model and v: +                    v = self.model.objects.get(pk=v) +                self.equiv_dict[value] = v              else:                  self.equiv_dict[value] = None @@ -294,11 +234,24 @@ class StrChoiceFormater(Formater):          if value in self.equiv_dict:              return self.equiv_dict[value] +class DateFormater(Formater): +    def __init__(self, date_format="%d/%m/%Y"): +        self.date_format = date_format + +    def format(self, value): +        value = value.strip() +        try: +            return 
datetime.datetime.strptime(value, self.date_format).date() +        except: +            raise ValueError(_(u"\"%(value)s\" is not a valid date") % { +                                                           'value':value}) +  logger = logging.getLogger(__name__)  class Importer(object):      LINE_FORMAT = []      OBJECT_CLS = None +    IMPORTED_LINE_FIELD = None      UNICITY_KEYS = []      DEFAULTS = {}      ERRORS = { @@ -356,23 +309,264 @@ class Importer(object):                      vals.append([])                  vals[idx_col].append(val)          for idx, formater in enumerate(self.line_format): -            formater.init(vals[idx]) +            if formater: +                formater.init(vals[idx])          self._initialized = True      def importation(self, table): +        self.validity_file = None          if not self._initialized:              self.initialize(table)          if self.check_validity:              with NamedTemporaryFile(delete=False) as validity_file: -                print(validity_file.name) -                validity_file_writer = UnicodeWriter(validity_file, +                self.validity_file = UnicodeWriter(validity_file,                                  delimiter=',', quotechar='"',                                  quoting=csv.QUOTE_MINIMAL) -                self._importation(table, validity_file_writer) +                self._importation(table)          else:              self._importation(table) -    def _importation(self, table, validity_file=None): +    @classmethod +    def _field_name_to_data_dict(cls, field_name, value, data, +                                 force_value=False, concat=False): +        field_names = field_name +        if type(field_names) not in (list, tuple): +            field_names = [field_name] +        for field_name in field_names: +            keys = field_name.split('__') +            current_data = data +            for idx, key in enumerate(keys): +                if idx == (len(keys) - 1): # 
last +                    if concat: +                        if not value: +                            value = "" +                        current_data[key] = (current_data[key] + u"\n") or u""\ +                                            + value +                    elif force_value and value: +                        current_data[key] = value +                    elif key not in current_data or not current_data[key]: +                        current_data[key] = value +                elif key not in current_data: +                    current_data[key] = {} +                current_data = current_data[key] +        return data + +    def _importation(self, table): +        table = list(table) +        if not table or not table[0]: +            raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER) +        if self.check_col_num and len(table[0]) > len(self.line_format): +            raise ImporterError(self.ERRORS['too_many_cols'] % { +                     'user_col':len(table[0]), 'ref_col':len(self.line_format)}) +        self.errors = [] +        self.messages = [] +        self.number_imported = 0 +        # index of the last required column +        for idx_last_col, formater in enumerate(reversed(self.line_format)): +            if formater and formater.required: +                break +        else: +            idx_last_col += 1 +        # min col number to be filled +        self.min_col_number = len(self.line_format) - idx_last_col +        # check the conformity with the reference header +        if self.reference_header and \ +           self.skip_first_line and \ +           self.reference_header != table[0]: +            raise ImporterError(self.ERRORS['header_check'], +                                type=ImporterError.HEADER) +        self.now = datetime.datetime.now() +        for idx_line, line in enumerate(table): +            self._line_processing(idx_line, line) + +    def _line_processing(self, idx_line, line): +        if 
(self.skip_first_line and not idx_line): +            if self.validity_file: +                self.validity_file.writerow(line) +            return +        if not line: +            if self.validity_file: +                self.validity_file.writerow([]) +            return +        self._throughs = [] # list of (formater, value) +        self._post_processing = [] # list of (formater, value) +        data = {} + +        # keep in database the raw line for testing purpose +        if self.IMPORTED_LINE_FIELD: +            output = io.StringIO() +            writer = csv.writer(output) +            writer.writerow(line) +            data[self.IMPORTED_LINE_FIELD] = output.getvalue() + +        n = datetime.datetime.now() +        logger.debug('%s - Processing line %d' % (unicode(n-self.now), idx_line)) +        self.now = n +        n2 = n +        self.c_errors = False +        c_row = [] +        for idx_col, val in enumerate(line): +            try: +                self._row_processing(c_row, idx_col, idx_line, val, data) +            except: +                pass + +        if self.validity_file: +            self.validity_file.writerow(c_row) +        if not self.c_errors and (idx_col + 1) < self.min_col_number: +            self.c_errors = True +            self.errors.append((idx_line+1, idx_col+1, +                          self.ERRORS['not_enough_cols'] % self.min_col_number)) +        if self.c_errors: +            return +        n = datetime.datetime.now() +        logger.debug('* %s - Cols read' % (unicode(n-n2))) +        n2 = n +        if self.test: +            return +        # manage unicity of items (mainly for updates) +        self.number_imported += 1 +        if self.UNICITY_KEYS: +            data['defaults'] = {} +            for k in data.keys(): +                if k not in self.UNICITY_KEYS \ +                   and k != 'defaults': +                    data['defaults'][k] = data.pop(k) + +        data['history_modifier'] = 
self.history_modifier +        obj, created = self.get_object(self.OBJECT_CLS, data) + +        if not created and 'defaults' in data: +            for k in data['defaults']: +                setattr(obj, k, data['defaults'][k]) +            obj.save() +        n = datetime.datetime.now() +        logger.debug('* %s - Item saved' % (unicode(n-n2))) +        n2 = n +        for formater, value in self._throughs: +            n = datetime.datetime.now() +            logger.debug('* %s - Processing formater %s' % (unicode(n-n2), +                                                    formater.field_name)) +            n2 = n +            data = {} +            if formater.through_dict: +                data = formater.through_dict.copy() +            if formater.through_key: +                data[formater.through_key] = obj +            data[formater.field_name] = value +            through_cls = formater.through +            if formater.through_unicity_keys: +                data['defaults'] = {} +                for k in data.keys(): +                    if k not in formater.through_unicity_keys \ +                       and k != 'defaults': +                        data['defaults'][k] = data.pop(k) +            t_obj, created = through_cls.objects.get_or_create(**data) +            if not created and 'defaults' in data: +                for k in data['defaults']: +                    setattr(t_obj, k, data['defaults'][k]) +                t_obj.save() + +        for formater, val in self._post_processing: +            formater.post_process(obj, data, val, owner=self.history_modifier) + +    def _row_processing(self, c_row, idx_col, idx_line, val, data): +        if idx_col >= len(self.line_format): +            return + +        formater = self.line_format[idx_col] + +        if formater and formater.post_processing: +            self._post_processing.append((formater, val)) + +        if not formater or not formater.field_name: +            if self.validity_file: +   
             c_row.append(val) +            return + +        # regex management +        if formater.regexp: +            # multiline regexp is a mess... +            val = val.replace('\n', NEW_LINE_BREAK) +            match = formater.regexp.match(val) +            if not match: +                if formater.required: +                    self.errors.append((idx_line+1, idx_col+1, +                                       self.ERRORS['value_required'])) +                elif not val.strip(): +                    c_row.append("") +                    return +                self.c_errors = True +                val = val.replace(NEW_LINE_BREAK, '\n') +                self.errors.append((idx_line+1, idx_col+1, +                         unicode(self.ERRORS['regex_not_match']) + val)) +                c_row.append("") +                return +            val_group = [v.replace(NEW_LINE_BREAK, '\n') +                         for v in match.groups()] +        else: +            val_group = [val] + +        c_values = [] +        for idx_v, v in enumerate(val_group): +            self.message = '' +            func = formater.formater +            if type(func) in (list, tuple): +                func = func[idx_v] +            if not callable(func) and type(func) in (unicode, str): +                func = getattr(self, func) +            value = None + +            try: +                if formater.regexp_formater_args: +                    args = [] +                    for idx in formater.regexp_formater_args[idx_v]: +                        args.append(val_group[idx]) +                    value = func.format(*args) +                else: +                    value = func.format(v) +            except ValueError, e: +                if formater.required: +                    self.c_errors = True +                self.errors.append((idx_line+1, idx_col+1, e.message)) +                c_values.append(None) +                return + +            if self.message: +          
      self.messages.append(self.message) +            c_values.append(value) + +            if value == None: +                if formater.required: +                    self.c_errors = True +                    self.errors.append((idx_line+1, idx_col+1, +                                       self.ERRORS['value_required'])) +                return + +            field_name = formater.field_name +            if type(field_name) in (list, tuple): +                field_name = field_name[idx_v] +            field_names = [field_name] +            if formater.duplicate_field: +                duplicate_field = formater.duplicate_field +                if type(duplicate_field) in (list, tuple): +                    duplicate_field = duplicate_field[idx_v] +                field_names += [duplicate_field] + +            if formater.through: +                self._throughs.append((formater, value)) +            else: +                for field_name in field_names: +                    self._field_name_to_data_dict(field_name, +                                              value, data, formater.force_value) +        if formater.reverse_for_test: +            c_row.append(formater.reverse_for_test(**c_values)) +        else: +            c_row.append(unicode(c_values)) + + +    """ +    def _importation(self, table):          table = list(table)          if not table or not table[0]:              raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER) @@ -398,33 +592,46 @@ class Importer(object):                                  type=ImporterError.HEADER)          now = datetime.datetime.now()          for idx_line, line in enumerate(table): +            #self._line_processing() +              if (self.skip_first_line and not idx_line): -                if validity_file: -                    validity_file.writerow(line) +                if self.validity_file: +                    self.validity_file.writerow(line)                  continue              if not 
line: -                if validity_file: -                    validity_file.writerow([]) +                if self.validity_file: +                    self.validity_file.writerow([])                  continue -            throughs = [] # list of (formater, value) +            self.throughs = [] # list of (formater, value) +            self.post_processing = [] # list of (formater, value)              data = {} + +            # keep in database the raw line for testing purpose +            if self.IMPORTED_LINE_FIELD: +                output = io.StringIO() +                writer = csv.writer(output) +                writer.writerow(line) +                data[self.IMPORTED_LINE_FIELD] = output.getvalue() +              n = datetime.datetime.now()              logger.debug('%s - Processing line %d' % (unicode(n-now), idx_line))              now = n              n2 = n -            c_errors = False +            self.c_errors = False              c_row = []              for idx_col, val in enumerate(line): +                #self._row_processing(self, c_row, idx_col, val): +                  if idx_col >= len(self.line_format):                      break                  formater = self.line_format[idx_col]                  if not formater.field_name: -                    if validity_file: +                    if self.validity_file:                          c_row.append(val)                      continue                  if formater.regexp:                      # multiline regexp is a mess... 
-                    val = val.replace('\n', '######???#####') +                    val = val.replace('\n', NEW_LINE_BREAK)                      match = formater.regexp.match(val)                      if not match:                          if formater.required: @@ -434,12 +641,12 @@ class Importer(object):                              c_row.append("")                              continue                          c_errors = True -                        val = val.replace('######???#####', '\n') +                        val = val.replace(NEW_LINE_BREAK, '\n')                          self.errors.append((idx_line+1, idx_col+1,                                   unicode(self.ERRORS['regex_not_match']) + val))                          c_row.append("")                          continue -                    val_group = [v.replace('######???#####', '\n') +                    val_group = [v.replace(NEW_LINE_BREAK, '\n')                                   for v in match.groups()]                  else:                      val_group = [val] @@ -483,29 +690,26 @@ class Importer(object):                          if type(duplicate_field) in (list, tuple):                              duplicate_field = duplicate_field[idx_v]                          field_names += [duplicate_field] -                    if not formater.through: -                        for field_name in field_names: -                            keys = field_name.split('__') -                            current_data = data -                            for idx, key in enumerate(keys): -                                if idx == (len(keys) - 1): # last -                                    current_data[key] = value -                                elif key not in current_data: -                                    current_data[key] = {} -                                current_data = current_data[key] -                    else: + + +                    if formater.through:                          
throughs.append((formater, value)) +                    else: +                        for field_name in field_names: +                            self._field_name_to_data_dict(field_name, +                                                          value, data)                  if formater.reverse_for_test:                      c_row.append(formater.reverse_for_test(**c_values))                  else:                      c_row.append(unicode(c_values)) -            if validity_file: -                validity_file.writerow(c_row) -            if not c_errors and (idx_col + 1) < min_col_number: -                c_errors = True + +            if self.validity_file: +                self.validity_file.writerow(c_row) +            if not self.c_errors and (idx_col + 1) < min_col_number: +                self.c_errors = True                  self.errors.append((idx_line+1, idx_col+1,                                 self.ERRORS['not_enough_cols'] % min_col_number)) -            if c_errors: +            if self.c_errors:                  continue              n = datetime.datetime.now()              logger.debug('* %s - Cols read' % (unicode(n-n2))) @@ -530,7 +734,7 @@ class Importer(object):              n = datetime.datetime.now()              logger.debug('* %s - Item saved' % (unicode(n-n2)))              n2 = n -            for formater, value in throughs: +            for formater, value in self.throughs:                  n = datetime.datetime.now()                  logger.debug('* %s - Processing formater %s' % (unicode(n-n2),                                                          formater.field_name)) @@ -553,6 +757,7 @@ class Importer(object):                      for k in data['defaults']:                          setattr(t_obj, k, data['defaults'][k])                      t_obj.save() +    """      def get_object(self, cls, data, path=[]):          m2ms = [] @@ -563,7 +768,8 @@ class Importer(object):                      continue                  
field_object, model, direct, m2m = \                                      cls._meta.get_field_by_name(attribute) -                if field_object.rel and type(data[attribute]) == dict: +                if hasattr(field_object, 'rel') and field_object.rel and \ +                   type(data[attribute]) == dict:                      c_path.append(attribute)                      # put history_modifier for every created item                      data[attribute]['history_modifier'] = \ @@ -577,15 +783,26 @@ class Importer(object):                  for k in self._defaults[path]:                      if k not in data or not data[k]:                          data[k] = self._defaults[path][k] -            obj, created = cls.objects.get_or_create(**data) -            for attr, value in m2ms: -                getattr(obj, attr).add(value) + +            # filter default values +            create_dict = copy.deepcopy(data) +            for k in create_dict.keys(): +                if type(create_dict[k]) == dict: +                    create_dict.pop(k) + +            try: +                obj, created = cls.objects.get_or_create(**create_dict) +                for attr, value in m2ms: +                    getattr(obj, attr).add(value) +            except IntegrityError: +                raise ImporterError("Erreur d'import %s, contexte : %s" \ +                                % (unicode(cls), unicode(data)))              return obj, created          return data      def get_csv_errors(self):          for line, col, error in self.errors: -            print '"%d","%d","%s"' % (line, col, unicode(error)) +            print('"%d","%d","%s"' % (line, col, unicode(error)))      @classmethod      def choices_check(cls, choices): | 
