From e4bf9fd5b98b25bb93fd3bef64c3b511851e6f73 Mon Sep 17 00:00:00 2001 From: Étienne Loks Date: Thu, 26 Apr 2018 13:23:45 +0200 Subject: Wizard: manage keep M2M member for formset with intermediary model --- ishtar_common/wizards.py | 56 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 13 deletions(-) (limited to 'ishtar_common/wizards.py') diff --git a/ishtar_common/wizards.py b/ishtar_common/wizards.py index 70f3caa19..89a155c1c 100644 --- a/ishtar_common/wizards.py +++ b/ishtar_common/wizards.py @@ -835,7 +835,17 @@ class Wizard(NamedUrlWizardView): value, created = model.objects.get_or_create( **value) else: - value = model.objects.create(**value) + if 'pk' in value and value['pk']: + try: + instance = model.objects.get( + pk=value.pop('pk')) + except model.DoesNotExist: + continue + for k in value: + setattr(instance, k, value[k]) + value = instance + else: + value = model.objects.create(**value) value.save() # force post_save # check that an item is not add multiple times (forged forms) if value not in related_model.all() and \ @@ -1252,6 +1262,22 @@ class Wizard(NamedUrlWizardView): initial[base_field] = unicode(value) return initial + @staticmethod + def _get_vals_for_instanced_init_for_formset(field, child_obj, vals): + if hasattr(child_obj, field): + value = getattr(child_obj, field) + if hasattr(value, 'pk'): + value = value.pk + if value is not None: + vals[field] = unicode(value) + elif hasattr(child_obj, field + "s"): + # M2M + vals.setlist(field, [ + unicode(v.pk) + for v in getattr(child_obj, field + "s").all() + ]) + return vals + def _get_instanced_init_for_formset(self, obj, current_step, c_form): """ Get initial data from an object: formset @@ -1278,6 +1304,16 @@ class Wizard(NamedUrlWizardView): if not through and not obj._meta.ordering: query = query.order_by('pk') + # an intermediary model is used + through_fields = [] + if through and not related.model._meta.auto_created: + target_field = getattr(related.model, c_form.form.base_model, None) + if target_field: + related_model = target_field.field.related_model + for field in related_model._meta.get_fields(): + through_fields.append(field.name) + through_fields.append('pk') + for child_obj in query.all(): if not keys: break @@ -1287,18 +1323,12 @@ class Wizard(NamedUrlWizardView): vals[keys[0]] = unicode(child_obj.pk) else: for field in keys: - if hasattr(child_obj, field): - value = getattr(child_obj, field) - if hasattr(value, 'pk'): - value = value.pk - if value is not None: - vals[field] = unicode(value) - elif hasattr(child_obj, field + "s"): - # M2M - vals.setlist(field, [ - unicode(v.pk) - for v in getattr(child_obj, field + "s").all() - ]) + vals = self._get_vals_for_instanced_init_for_formset( + field, child_obj, vals) + for field in through_fields: + related_obj = getattr(child_obj, c_form.form.base_model) + vals = self._get_vals_for_instanced_init_for_formset( + field, related_obj, vals) if vals: initial.append(vals) return initial -- cgit v1.2.3