diff options
Diffstat (limited to 'ishtar_common/wizards.py')
| -rw-r--r-- | ishtar_common/wizards.py | 56 | 
1 files changed, 43 insertions, 13 deletions
| diff --git a/ishtar_common/wizards.py b/ishtar_common/wizards.py index 70f3caa19..89a155c1c 100644 --- a/ishtar_common/wizards.py +++ b/ishtar_common/wizards.py @@ -835,7 +835,17 @@ class Wizard(NamedUrlWizardView):                              value, created = model.objects.get_or_create(                                  **value)                          else: -                            value = model.objects.create(**value) +                            if 'pk' in value and value['pk']: +                                try: +                                    instance = model.objects.get( +                                        pk=value.pop('pk')) +                                except model.DoesNotExist: +                                    continue +                                for k in value: +                                    setattr(instance, k, value[k]) +                                value = instance +                            else: +                                value = model.objects.create(**value)                          value.save()  # force post_save                  # check that an item is not add multiple times (forged forms)                  if value not in related_model.all() and \ @@ -1252,6 +1262,22 @@ class Wizard(NamedUrlWizardView):                  initial[base_field] = unicode(value)          return initial +    @staticmethod +    def _get_vals_for_instanced_init_for_formset(field, child_obj, vals): +        if hasattr(child_obj, field): +            value = getattr(child_obj, field) +            if hasattr(value, 'pk'): +                value = value.pk +            if value is not None: +                vals[field] = unicode(value) +        elif hasattr(child_obj, field + "s"): +            # M2M +            vals.setlist(field, [ +                unicode(v.pk) +                for v in getattr(child_obj, field + "s").all() +            ]) +        return vals +      def _get_instanced_init_for_formset(self, obj, current_step, c_form):          """          Get initial data from an object: formset @@ -1278,6 +1304,16 @@ class Wizard(NamedUrlWizardView):          if not through and not obj._meta.ordering:              query = query.order_by('pk') +        # an intermediary model is used +        through_fields = [] +        if through and not related.model._meta.auto_created: +            target_field = getattr(related.model, c_form.form.base_model, None) +            if target_field: +                related_model = target_field.field.related_model +                for field in related_model._meta.get_fields(): +                    through_fields.append(field.name) +                through_fields.append('pk') +          for child_obj in query.all():              if not keys:                  break @@ -1287,18 +1323,12 @@ class Wizard(NamedUrlWizardView):                  vals[keys[0]] = unicode(child_obj.pk)              else:                  for field in keys: -                    if hasattr(child_obj, field): -                        value = getattr(child_obj, field) -                        if hasattr(value, 'pk'): -                            value = value.pk -                        if value is not None: -                            vals[field] = unicode(value) -                    elif hasattr(child_obj, field + "s"): -                        # M2M -                        vals.setlist(field, [ -                            unicode(v.pk) -                            for v in getattr(child_obj, field + "s").all() -                        ]) +                    vals = self._get_vals_for_instanced_init_for_formset( +                        field, child_obj, vals) +                for field in through_fields: +                    related_obj = getattr(child_obj, c_form.form.base_model) +                    vals = self._get_vals_for_instanced_init_for_formset( +                        field, related_obj, vals)              if vals:                  initial.append(vals)          return initial | 
