diff options
Diffstat (limited to 'ishtar_common')
| -rw-r--r-- | ishtar_common/data_importer.py | 604 | ||||
| -rw-r--r-- | ishtar_common/unicode_csv.py | 79 | 
2 files changed, 683 insertions, 0 deletions
| diff --git a/ishtar_common/data_importer.py b/ishtar_common/data_importer.py new file mode 100644 index 000000000..86285e33e --- /dev/null +++ b/ishtar_common/data_importer.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (C) 2013-2014  Étienne Loks  <etienne.loks_AT_peacefrogsDOTnet> + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. + +# See the file COPYING for details. 
+ +""" +# Usage exemple (extracted from simulabio application) + +class HarvestPlotImporter(Importer): +    LINE_FORMAT = [ +               ImportFormater('name', Importer.get_unicode_formater(100)), +               ImportFormater('plot_group_number', +                              Importer.get_unicode_formater(3), required=False), +               ImportFormater('geographical_area', unicode, required=False), +               ImportFormater('soil_type', Importer.choices_check(SOIL_TYPE), +                                                            required=False), +               ImportFormater('cow_access', Importer.boolean_formater), +               ImportFormater('area', Importer.float_formater), +               ImportFormater('remark', unicode, required=False), +               ImportFormater('diagnostic', Importer.boolean_formater), +               ImportFormater('project', Importer.boolean_formater), +               ImportFormater('harvest_n2', 'harvest_formater', required=False), +               ImportFormater('harvest_n1', 'harvest_formater', required=False), +               ImportFormater('harvest', 'harvest_formater'), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                              through_key='plot', +                              through_dict={'year':1}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                              through_key='plot', +                              through_dict={'year':2}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                     
         through_key='plot', +                              through_dict={'year':3}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                              through_key='plot', +                              through_dict={'year':4}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                              through_key='plot', +                              through_dict={'year':5}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ImportFormater('harvest_setting', 'harvest_formater', +                              through=HarvestTransition, +                              through_key='plot', +                              through_dict={'year':6}, +                              through_unicity_keys=['plot', 'year'], +                              required=False), +               ] +    OBJECT_CLS = HarvestPlots +    UNICITY_KEYS = [] + +    def __init__(self, study, skip_first_line=None): +        # get the reference header +        dct = {'separator':settings.CSV_DELIMITER} +        dct['data'] = Harvest.objects.filter(available=True).all() +        reference_file = render_to_string('simulabio/files/parcelles_ref.csv', +                                          dct) +        reference_header = unicode_csv_reader( +                                [reference_file.split('\n')[0]]).next() +        super(HarvestPlotImporter, self).__init__( +                                           skip_first_line=skip_first_line, +                                           
reference_header=reference_header) +        self.study = study +        self.default_vals = {'study':self.study} + +    def harvest_formater(self, value): +        value = value.strip() +        if not value: +            return +        try: +            harvest = Harvest.objects.get(name__iexact=value) +        except ObjectDoesNotExist: +            raise ValueError(_(u"\"%(value)s\" not in %(values)s") % { +                'value':value, +                'values':u", ".join([val.name +                        for val in Harvest.objects.filter(available=True)]) +                }) +        hs, created = HarvestSettings.objects.get_or_create(study=self.study, +                                                            harvest=harvest) +        if created: +            self.message = _(u"\"%(harvest)s\" has been added in your settings " +                             u"don't forget to fill yields for this harvest.") \ +                             % {'harvest':harvest.name} +        return hs + +class HarvestPlotsImportForm(forms.Form): +    csv_file = forms.FileField(label=_(u"Plot list file (CSV)")) + +    def save(self, study): +        csv_file = self.cleaned_data['csv_file'] +        importer = models.HarvestPlotImporter(study, skip_first_line=True) +        # some softwares (at least Gnumeric) convert CSV file to utf-8 no matter +        # what the CSV source encoding is +        encodings = [settings.ENCODING, 'utf-8'] +        for encoding in encodings: +            try: +                importer.importation(unicode_csv_reader( +                                [line.decode(encoding) +                                 for line in csv_file.readlines()])) +            except ImporterError, e: +                if e.type == ImporterError.HEADER and encoding != encodings[-1]: +                    csv_file.seek(0) +                    continue +                return 0, [[0, 0, e.msg]], [] +            except UnicodeDecodeError, e: +                return 0, [[0, 
class ImportFormater(object):
    """Description of one import column.

    Ties a column of the imported table to a model field: which field it
    feeds (``field_name``, a ``__``-separated path is honoured by
    ``Importer._importation``), how the raw cell is converted
    (``formater``: a ``Formater`` instance, a list of instances, or the
    name of a method of the ``Importer`` subclass), and optionally how
    the value is attached via an intermediary model (``through``,
    ``through_key``, ``through_dict``, ``through_unicity_keys``).

    ``regexp`` may split the cell into groups; ``regexp_formater_args``
    then lists, per formater, the group indexes used as arguments.
    ``duplicate_field`` names extra field(s) fed with the same value and
    ``reverse_for_test`` rebuilds a cell from parsed values for the
    "validity" CSV. ``required`` cells that stay empty are reported as
    errors.
    """

    def __init__(self, field_name, formater=None, required=True, through=None,
                through_key=None, through_dict=None, through_unicity_keys=None,
                duplicate_field=None, regexp=None, regexp_formater_args=None,
                reverse_for_test=None, comment=""):
        self.field_name = field_name
        self.formater = formater
        self.required = required
        self.through = through
        self.through_key = through_key
        self.through_dict = through_dict
        self.through_unicity_keys = through_unicity_keys
        self.duplicate_field = duplicate_field
        self.regexp = regexp
        # bugfix: the former literal [] default was a mutable default
        # argument shared by every instance; build a fresh list instead
        if regexp_formater_args is None:
            regexp_formater_args = []
        self.regexp_formater_args = regexp_formater_args
        self.reverse_for_test = reverse_for_test
        self.comment = comment

    def __unicode__(self):
        return self.field_name

    def report_succes(self, *args):
        # no-op hook; NOTE(review): name kept misspelled ("succes") on
        # purpose - renaming would break external callers/overrides
        return

    def report_error(self, *args):
        # no-op hook for subclasses
        return

    def init(self, vals):
        # self.formater may be a single formater or an iterable of them
        # NOTE(review): a *string* formater (importer method name) would be
        # iterated char by char here - presumably init() is only reached
        # with Formater instances; confirm against Importer.initialize
        try:
            lst = iter(self.formater)
        except TypeError:
            lst = [self.formater]
        for formater in lst:
            formater.check(vals)

class ImporterError(Exception):
    """Error raised during an import.

    ``type`` distinguishes plain data errors (STANDARD) from header /
    format mismatches (HEADER) so callers can e.g. retry with another
    encoding.
    """
    STANDARD = 'S'
    HEADER = 'H'

    def __init__(self, message, type='S'):
        self.msg = message
        self.type = type

    def __str__(self):
        return self.msg

class Formater(object):
    """Base cell formater: format() converts one raw cell value,
    check() pre-validates a whole column (both no-ops by default)."""

    def format(self, value):
        return value

    def check(self, values):
        return
class UnicodeFormater(Formater):
    """Convert a raw cell into a unicode string of bounded length.

    ``re_filter`` optionally keeps only the groups matched by the given
    compiled regexp; ``clean`` drops a single leading and/or trailing
    comma. Values longer than ``max_length`` raise ValueError.
    """

    def __init__(self, max_length, clean=False, re_filter=None):
        self.max_length = max_length
        self.clean = clean
        self.re_filter = re_filter

    def format(self, value):
        try:
            value = unicode(value.strip())
            if self.re_filter:
                filtered = self.re_filter.match(value)
                if filtered:
                    value = u"".join(filtered.groups())
            if self.clean:
                # drop at most one comma at each end
                value = value[1:] if value.startswith(",") else value
                value = value[:-1] if value.endswith(",") else value
        except UnicodeDecodeError:
            # undecodable input is silently treated as an empty value
            return
        if len(value) <= self.max_length:
            return value
        raise ValueError(_(u"\"%(value)s\" is too long. "
                           u"The max length is %(length)d characters."
                           ) % {'value':value, 'length':self.max_length})

class BooleanFormater(Formater):
    """Map common French/English yes-no spellings onto a boolean."""

    def format(self, value):
        value = value.strip().upper()
        # empty counts as "no"
        if value in ('', '0', 'NON', 'FAUX', 'NO', 'FALSE'):
            return False
        if value in ('1', 'OUI', 'VRAI', 'YES', 'TRUE'):
            return True
        raise ValueError(_(u"\"%(value)s\" not equal to yes or no") % {
                                                                 'value':value})

class FloatFormater(Formater):
    """Parse a float, accepting the French decimal comma; empty -> None."""

    def format(self, value):
        cleaned = value.strip().replace(',', '.')
        if not cleaned:
            return
        try:
            return float(cleaned)
        except ValueError:
            raise ValueError(_(u"\"%(value)s\" is not a float") % {
                                                               'value':cleaned})
class StrChoiceFormater(Formater):
    """Map free-text cells onto the keys of a Django-style ``choices`` list.

    ``equiv_dict`` maps (optionally slugified, when ``strict`` is False)
    labels to choice keys; it is seeded from ``choices`` and may be
    pre-filled by the caller. With ``cli`` set, unknown values trigger an
    interactive prompt on stdin; otherwise they are collected in
    ``self.missings``.
    """

    def __init__(self, choices, strict=False, equiv_dict=None, cli=False):
        self.choices = choices
        self.strict = strict
        # bugfix: no shared mutable {} default argument any more; keep the
        # defensive deepcopy of a caller-provided mapping
        self.equiv_dict = copy.deepcopy(equiv_dict) if equiv_dict else {}
        self.cli = cli
        self.missings = set()
        for key, value in self.choices:
            value = unicode(value)
            if not self.strict:
                value = slugify(value)
            if value not in self.equiv_dict:
                self.equiv_dict[value] = key

    def prepare(self, value):
        # normalize a raw cell before any lookup
        return unicode(value).strip()

    def check(self, values):
        # build the interactive menu once for the whole column
        msgstr = unicode(_(u"Choice for \"%s\" is not available. "
                           u"Which one is relevant?\n"))
        for idx, choice in enumerate(self.choices):
            msgstr += u"%d. %s\n" % (idx+1, choice[1])
        # bugfix: numbering based on len() - the previous "idx + 2" raised
        # a NameError when self.choices was empty
        msgstr += unicode(_(u"%d. None of the above")) % \
            (len(self.choices) + 1) + u"\n"
        for value in values:
            value = self.prepare(value)
            if value in self.equiv_dict:
                continue
            if not self.cli:
                self.missings.add(value)
                continue
            res = None
            while res not in range(1, len(self.choices)+2):
                sys.stdout.write(msgstr % value)
                res = raw_input(">>> ")
                try:
                    res = int(res)
                except ValueError:
                    pass
            res -= 1
            if res < len(self.choices):
                # bugfix: store the choice *key*, as __init__ does (the
                # whole (key, label) tuple was stored before, making
                # format() return inconsistent types on the CLI path)
                self.equiv_dict[value] = self.choices[res][0]
            else:
                self.equiv_dict[value] = None

    def format(self, value):
        # NOTE(review): check() registers unknown values without slugifying
        # them while format() looks them up slugified (strict=False) -
        # interactively added entries may therefore not be found here;
        # confirm intended behavior before changing
        value = self.prepare(value)
        if not self.strict:
            value = slugify(value)
        if value in self.equiv_dict:
            return self.equiv_dict[value]
class Importer(object):
    """Generic table importer.

    Subclasses define LINE_FORMAT (one ImportFormater per column),
    OBJECT_CLS (the target Django model), UNICITY_KEYS (fields identifying
    an existing row to update) and DEFAULTS (per-path default values used
    by get_object). Results are accumulated in ``errors``, ``messages``
    and ``number_imported``.
    """
    LINE_FORMAT = []
    OBJECT_CLS = None
    UNICITY_KEYS = []
    DEFAULTS = {}
    # user-facing error messages, keyed by symbolic name
    ERRORS = {
        'header_check':_(u"The given file is not correct. Check the file "
                  u"format. If you use a CSV file: check that column separator "
                  u"and encoding are similar to the ones used by the reference "
                  u"file."),
        'too_many_cols':_(u"Too many cols (%(user_col)d) when "
                         u"maximum is %(ref_col)d"),
        'no_data':_(u"No data provided"),
        'value_required':_(u"Value is required"),
        'not_enough_cols':_(u"At least %d columns must be filled"),
        'regex_not_match':_(u"The regexp doesn't match.")
        }

    def __init__(self, skip_first_line=False, reference_header=None,
                 check_col_num=False, test=False, check_validity=True,
                 history_modifier=None):
        """
         * skip_first_line must be set to True if the data provided has got
           an header.
         * a reference_header can be provided to perform a data compliance
           check. It can be useful to warn about bad parsing.
         * test doesn't write in the database
         * check_validity rewrite a CSV file to be compared
        """
        self.message = ''
        self.skip_first_line = skip_first_line
        self.reference_header = reference_header
        self.test = test
        self.errors = [] # list of (line, col, message)
        self.messages = [] # list of (line, col, message)
        self.number_imported = 0
        self.check_col_num = check_col_num
        self.check_validity = check_validity
        # shallow copy so per-instance tweaks don't mutate the class attr
        self.line_format = copy.copy(self.LINE_FORMAT)
        self._initialized = False
        self._defaults = self.DEFAULTS.copy()
        self.history_modifier = history_modifier
        if not self.history_modifier:
            # get the first admin
            # NOTE(review): raises IndexError if no superuser exists -
            # presumably guaranteed by deployment; confirm
            self.history_modifier = User.objects.filter(is_superuser=True
                                                ).order_by('pk')[0]

    def initialize(self, table):
        """Transpose the table and let each formater pre-scan its column
        (e.g. StrChoiceFormater collecting unknown values)."""
        # copy vals in columns
        vals = []
        for idx_line, line in enumerate(table):
            if (self.skip_first_line and not idx_line):
                continue
            for idx_col, val in enumerate(line):
                if idx_col >= len(self.line_format):
                    # extra columns beyond the described format are ignored
                    break
                if idx_col >= len(vals):
                    vals.append([])
                vals[idx_col].append(val)
        for idx, formater in enumerate(self.line_format):
            formater.init(vals[idx])
        self._initialized = True

    def importation(self, table):
        """Entry point: run the import, optionally writing a re-generated
        "validity" CSV (kept on disk, delete=False) for later comparison."""
        if not self._initialized:
            self.initialize(table)
        if self.check_validity:
            with NamedTemporaryFile(delete=False) as validity_file:
                # NOTE(review): debug print left in - the temp file path is
                # written to stdout on every run
                print(validity_file.name)
                validity_file_writer = UnicodeWriter(validity_file,
                                delimiter=',', quotechar='"',
                                quoting=csv.QUOTE_MINIMAL)
                self._importation(table, validity_file_writer)
        else:
            self._importation(table)

    def _importation(self, table, validity_file=None):
        """Parse every line, create/update OBJECT_CLS rows then "through"
        rows; errors are collected per (line, col) rather than raised."""
        table = list(table)
        if not table or not table[0]:
            raise ImporterError(self.ERRORS['no_data'], ImporterError.HEADER)
        if self.check_col_num and len(table[0]) > len(self.line_format):
            raise ImporterError(self.ERRORS['too_many_cols'] % {
                     'user_col':len(table[0]), 'ref_col':len(self.line_format)})
        self.errors = []
        self.messages = []
        self.number_imported = 0
        # index of the last required column
        # NOTE(review): if line_format is empty the for/else hits an
        # undefined idx_last_col (NameError) - confirm LINE_FORMAT is
        # never empty in practice
        for idx_last_col, formater in enumerate(reversed(self.line_format)):
            if formater.required:
                break
        else:
            idx_last_col += 1
        # min col number to be filled
        min_col_number = len(self.line_format) - idx_last_col
        # check the conformity with the reference header
        if self.reference_header and \
           self.skip_first_line and \
           self.reference_header != table[0]:
            raise ImporterError(self.ERRORS['header_check'],
                                type=ImporterError.HEADER)
        now = datetime.datetime.now()
        for idx_line, line in enumerate(table):
            # the header line is copied verbatim to the validity file
            if (self.skip_first_line and not idx_line):
                if validity_file:
                    validity_file.writerow(line)
                continue
            if not line:
                if validity_file:
                    validity_file.writerow([])
                continue
            throughs = [] # list of (formater, value)
            data = {}  # nested dict of values to feed OBJECT_CLS
            n = datetime.datetime.now()
            logger.debug('%s - Processing line %d' % (unicode(n-now), idx_line))
            now = n
            n2 = n
            c_errors = False
            c_row = []  # re-generated cells for the validity file
            for idx_col, val in enumerate(line):
                if idx_col >= len(self.line_format):
                    break
                formater = self.line_format[idx_col]
                # columns with no field_name are passed through untouched
                if not formater.field_name:
                    if validity_file:
                        c_row.append(val)
                    continue
                if formater.regexp:
                    # multiline regexp is a mess... newlines are masked with
                    # a sentinel before matching and restored afterwards
                    val = val.replace('\n', '######???#####')
                    match = formater.regexp.match(val)
                    if not match:
                        if formater.required:
                            self.errors.append((idx_line+1, idx_col+1,
                                               self.ERRORS['value_required']))
                        elif not val.strip():
                            # optional and empty: not an error
                            c_row.append("")
                            continue
                        c_errors = True
                        val = val.replace('######???#####', '\n')
                        self.errors.append((idx_line+1, idx_col+1,
                                 unicode(self.ERRORS['regex_not_match']) + val))
                        c_row.append("")
                        continue
                    val_group = [v.replace('######???#####', '\n')
                                 for v in match.groups()]
                else:
                    val_group = [val]
                c_values = []
                for idx_v, v in enumerate(val_group):
                    # formaters may set self.message to report a notice
                    self.message = ''
                    func = formater.formater
                    if type(func) in (list, tuple):
                        # one formater per regexp group
                        func = func[idx_v]
                    if not callable(func) and type(func) in (unicode, str):
                        # string formater: resolved as a method of self
                        func = getattr(self, func)
                    value = None
                    try:
                        if formater.regexp_formater_args:
                            # feed the formater with the configured groups
                            args = []
                            for idx in formater.regexp_formater_args[idx_v]:
                                args.append(val_group[idx])
                            value = func.format(*args)
                        else:
                            value = func.format(v)
                    except ValueError, e:
                        # NOTE(review): e.message is Python 2 only
                        c_errors = True
                        self.errors.append((idx_line+1, idx_col+1, e.message))
                        c_values.append(None)
                        continue
                    if self.message:
                        self.messages.append(self.message)
                    c_values.append(value)
                    if value == None:
                        if formater.required:
                            c_errors = True
                            self.errors.append((idx_line+1, idx_col+1,
                                               self.ERRORS['value_required']))
                        continue
                    field_name = formater.field_name
                    if type(field_name) in (list, tuple):
                        field_name = field_name[idx_v]
                    field_names = [field_name]
                    if formater.duplicate_field:
                        duplicate_field = formater.duplicate_field
                        if type(duplicate_field) in (list, tuple):
                            duplicate_field = duplicate_field[idx_v]
                        field_names += [duplicate_field]
                    if not formater.through:
                        # "__" paths build nested dicts resolved later by
                        # get_object
                        for field_name in field_names:
                            keys = field_name.split('__')
                            current_data = data
                            for idx, key in enumerate(keys):
                                if idx == (len(keys) - 1): # last
                                    current_data[key] = value
                                elif key not in current_data:
                                    current_data[key] = {}
                                current_data = current_data[key]
                    else:
                        # deferred: through rows need the main object first
                        throughs.append((formater, value))
                if formater.reverse_for_test:
                    c_row.append(formater.reverse_for_test(**c_values))
                else:
                    c_row.append(unicode(c_values))
            if validity_file:
                validity_file.writerow(c_row)
            # idx_col still holds the last processed column index here
            if not c_errors and (idx_col + 1) < min_col_number:
                c_errors = True
                self.errors.append((idx_line+1, idx_col+1,
                               self.ERRORS['not_enough_cols'] % min_col_number))
            if c_errors:
                continue
            n = datetime.datetime.now()
            logger.debug('* %s - Cols read' % (unicode(n-n2)))
            n2 = n
            if self.test:
                # dry run: parse and report only, never write
                continue
            # manage unicity of items (mainly for updates)
            self.number_imported += 1
            if self.UNICITY_KEYS:
                # non-key values become get_or_create "defaults"
                data['defaults'] = {}
                for k in data.keys():
                    if k not in self.UNICITY_KEYS \
                       and k != 'defaults':
                        data['defaults'][k] = data.pop(k)

            obj, created = self.get_object(self.OBJECT_CLS, data)

            if not created and 'defaults' in data:
                # existing row: apply the non-key values as an update
                for k in data['defaults']:
                    setattr(obj, k, data['defaults'][k])
                obj.save()
            n = datetime.datetime.now()
            logger.debug('* %s - Item saved' % (unicode(n-n2)))
            n2 = n
            for formater, value in throughs:
                n = datetime.datetime.now()
                logger.debug('* %s - Processing formater %s' % (unicode(n-n2),
                                                        formater.field_name))
                n2 = n
                data = {}
                if formater.through_dict:
                    data = formater.through_dict.copy()
                if formater.through_key:
                    # link the through row back to the main object
                    data[formater.through_key] = obj
                data[formater.field_name] = value
                through_cls = formater.through
                if formater.through_unicity_keys:
                    data['defaults'] = {}
                    for k in data.keys():
                        if k not in formater.through_unicity_keys \
                           and k != 'defaults':
                            data['defaults'][k] = data.pop(k)
                t_obj, created = through_cls.objects.get_or_create(**data)
                if not created and 'defaults' in data:
                    for k in data['defaults']:
                        setattr(t_obj, k, data['defaults'][k])
                    t_obj.save()

    def get_object(self, cls, data, path=[]):
        """Recursively resolve a nested ``data`` dict into a ``cls``
        instance via get_or_create; returns (obj, created).

        NOTE(review): ``path=[]`` is a mutable default argument - safe
        only because path is never mutated in place (copied via path[:]).
        """
        m2ms = []
        if data and type(data) == dict:
            for attribute in data.keys():
                c_path = path[:]
                # empty values are skipped entirely
                if not data[attribute]:
                    continue
                field_object, model, direct, m2m = \
                                    cls._meta.get_field_by_name(attribute)
                if field_object.rel and type(data[attribute]) == dict:
                    # nested dict on a relation: resolve it first
                    c_path.append(attribute)
                    # put history_modifier for every created item
                    data[attribute]['history_modifier'] = \
                                                    self.history_modifier
                    data[attribute], created = self.get_object(
                                   field_object.rel.to, data[attribute], c_path)
                if m2m:
                    # m2m values must be added after the object exists
                    m2ms.append((attribute, data.pop(attribute)))
            path = tuple(path)
            # fill per-path defaults for missing/empty values
            if path in self._defaults:
                for k in self._defaults[path]:
                    if k not in data or not data[k]:
                        data[k] = self._defaults[path][k]
            obj, created = cls.objects.get_or_create(**data)
            for attr, value in m2ms:
                getattr(obj, attr).add(value)
            return obj, created
        # non-dict data (e.g. an already-resolved object) passes through
        return data

    def get_csv_errors(self):
        # dump collected errors as quoted CSV on stdout (Python 2 print)
        for line, col, error in self.errors:
            print '"%d","%d","%s"' % (line, col, unicode(error))

    @classmethod
    def choices_check(cls, choices):
        """Return a formater function accepting only the labels of
        ``choices`` (empty cells map to None)."""
        def function(value):
            choices_dct = dict(choices)
            value = value.strip()
            if not value:
                return
            if value not in choices_dct.values():
                raise ValueError(_(u"\"%(value)s\" not in %(values)s") % {
                    'value':value,
                    'values':u", ".join([val for val in choices_dct.values()])
                    })
            return value
        return function
def utf_8_encoder(unicode_csv_data):
    """Yield each unicode line of ``unicode_csv_data`` encoded as UTF-8
    bytes (csv.py in Python 2 only handles byte strings)."""
    for line in unicode_csv_data:
        yield line.encode('utf-8')

def unicode_csv_reader(unicode_csv_data, dialect=None, reference_header=None,
                       **kwargs):
    """Wrap ``csv.reader`` so it consumes and yields unicode rows.

    ``unicode_csv_data`` must be indexable (e.g. a list of lines) when no
    ``dialect`` is given, since the first line is used to sniff the CSV
    dialect.
    ``reference_header`` is unused; it is kept for backward compatibility
    with existing callers (bugfix: its former shared mutable ``[]``
    default is replaced by ``None``).
    """
    if not dialect:
        dialect = csv.Sniffer().sniff(unicode_csv_data[0])
        # csv.py don't like unicode
        dialect.delimiter = str(dialect.delimiter)
        dialect.quotechar = str(dialect.quotechar)
    # csv.py doesn't do Unicode; encode temporarily as UTF-8:
    csv_reader = csv.reader(utf_8_encoder(unicode_csv_data),
                            dialect=dialect, **kwargs)
    for row in csv_reader:
        # decode UTF-8 back to Unicode, cell by cell:
        yield [unicode(cell, 'utf-8') for cell in row]

class UTF8Recoder:
    """
    Iterator that reads an encoded stream and reencodes the input to UTF-8
    """
    def __init__(self, f, encoding):
        self.reader = codecs.getreader(encoding)(f)

    def __iter__(self):
        return self

    def next(self):
        # Python 2 iterator protocol
        return self.reader.next().encode("utf-8")
class UnicodeReader:
    """Iterate over the rows of the CSV stream ``f`` (stored in the given
    ``encoding``), yielding each row as a list of unicode cells."""

    def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
        # recode the source stream to UTF-8 so csv.py can parse it
        recoded = UTF8Recoder(f, encoding)
        self.reader = csv.reader(recoded, dialect=dialect, **kwds)

    def __iter__(self):
        return self

    def next(self):
        # Python 2 iterator protocol: decode every cell back to unicode
        return [unicode(cell, "utf-8") for cell in self.reader.next()]

class UnicodeWriter:
    """Write unicode rows to the CSV stream ``f``, re-encoded into the
    given target ``encoding``."""

    def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
        # csv.py writes bytes, so serialize into a UTF-8 scratch buffer
        self.queue = cStringIO.StringIO()
        self.writer = csv.writer(self.queue, dialect=dialect, **kwds)
        self.stream = f
        # incremental encoder keeps state across rows (e.g. BOM once)
        self.encoder = codecs.getincrementalencoder(encoding)()

    def writerow(self, row):
        # serialize the row as UTF-8 into the scratch buffer ...
        self.writer.writerow([cell.encode("utf-8") for cell in row])
        utf8_payload = self.queue.getvalue()
        # ... then transcode to the target encoding and flush to the stream
        self.stream.write(self.encoder.encode(utf8_payload.decode("utf-8")))
        # reset the scratch buffer for the next row
        self.queue.truncate(0)

    def writerows(self, rows):
        for row in rows:
            self.writerow(row)
