diff options
Diffstat (limited to 'ishtar_common/wizards.py')
| -rw-r--r-- | ishtar_common/wizards.py | 40 | 
1 files changed, 30 insertions, 10 deletions
| diff --git a/ishtar_common/wizards.py b/ishtar_common/wizards.py index 7fc22f1a9..c8467ca61 100644 --- a/ishtar_common/wizards.py +++ b/ishtar_common/wizards.py @@ -32,7 +32,7 @@ from django.db.models.fields.files import FileField  from django.db.models.fields.related import ManyToManyField  from django.http import HttpResponseRedirect  from django.forms import ValidationError -from django.shortcuts import render_to_response +from django.shortcuts import render_to_response, redirect  from django.template import Context, RequestContext, loader  from django.utils.datastructures import MultiValueDict as BaseMultiValueDict  from django.utils.translation import ugettext_lazy as _ @@ -557,10 +557,15 @@ class Wizard(NamedUrlWizardView):          m2m_items = {}          # clear          # TODO! perf - to be really optimized +        old_m2ms = {}          for model in whole_associated_models:              related_model = getattr(obj, model + 's')              # clear real m2m              if hasattr(related_model, 'clear'): +                old_m2ms[model] = [] +                # stock items in order to not recreate them +                for old_item in related_model.all(): +                    old_m2ms[model].append(old_item)                  related_model.clear()              else:                  for r in related_model.all(): @@ -581,16 +586,32 @@ class Wizard(NamedUrlWizardView):              if value not in m2m_items[key]:                  if type(value) == dict:                      model = related_model.model -                    if issubclass(model, models.BaseHistorizedItem): -                        value['history_modifier'] = self.request.user                      # not m2m -> foreign key                      if not hasattr(related_model, 'clear'):                          assert hasattr(model, 'MAIN_ATTR'), \                              u"Must define a MAIN_ATTR for " + \                              unicode(model.__class__)                          value[getattr(model, 'MAIN_ATTR')] = obj -                    value = model.objects.create(**value) -                    value.save() + +                    # check old links +                    my_old_item = None +                    if key in old_m2ms: +                        for old_item in old_m2ms[key]: +                            is_ok = True +                            for k in value: +                                if is_ok and getattr(old_item, k) != value[k]: +                                    is_ok = False +                                    continue +                            if is_ok: +                                my_old_item = old_item +                                break +                    if my_old_item: +                        value = my_old_item +                    else: +                        if issubclass(model, models.BaseHistorizedItem): +                            value['history_modifier'] = self.request.user +                        value = model.objects.create(**value) +                        value.save()                  # check that an item is not add multiple times (forged forms)                  if value not in related_model.all() and\                          hasattr(related_model, 'add'): @@ -632,7 +653,7 @@ class Wizard(NamedUrlWizardView):          return (to_delete, not_to_delete)      def get_form(self, step=None, data=None, files=None): -        """Manage formset""" +        # Manage formset          if data:              data = data.copy()              if not step: @@ -695,6 +716,7 @@ class Wizard(NamedUrlWizardView):                  #    for k in init[0]:                  #        data[step + '-' + unicode(total_field) + '-' + k] = \                  #                                                      init[0][k] +          data = data or None          form = super(Wizard, self).get_form(step, data, files)          # add autofocus to first field @@ -723,6 +745,7 @@ class Wizard(NamedUrlWizardView):                      if frm.fields[key].widget.source_full is not None:                          frm.fields[key].widget.source_full = unicode(                              frm.fields[key].widget.source_full) + "own/" +          return form      def render_next_step(self, form, **kwargs): @@ -760,10 +783,7 @@ class Wizard(NamedUrlWizardView):          except (ValueError, IndexError):              return super(Wizard, self).post(*args, **kwargs)          self.storage.current_step = wizard_goto_step -        form = self.get_form( -            data=self.storage.get_step_data(self.steps.current), -            files=self.storage.get_step_files(self.steps.current)) -        return self.render(form) +        return redirect(self.get_step_url(wizard_goto_step))      def session_get_keys(self, form_key):          """Get list of available keys for a specific form | 
