diff options
Diffstat (limited to 'ishtar_common/model_merging.py')
| -rw-r--r-- | ishtar_common/model_merging.py | 35 | 
1 files changed, 21 insertions, 14 deletions
diff --git a/ishtar_common/model_merging.py b/ishtar_common/model_merging.py index b8c145fcb..c577a8cf1 100644 --- a/ishtar_common/model_merging.py +++ b/ishtar_common/model_merging.py @@ -4,6 +4,7 @@ from django.db import transaction  from django.db.models import get_models, Model  from django.contrib.contenttypes.generic import GenericForeignKey +  @transaction.commit_on_success  def merge_model_objects(primary_object, alias_objects=[], keep_old=False):      """ @@ -34,14 +35,13 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):              raise TypeError('Only models of same class can be merged')      # Get a list of all GenericForeignKeys in all models -    # TODO: this is a bit of a hack, since the generics framework should provide -    # a similar +    # TODO: this is a bit of a hack, since the generics framework should +    # provide a similar      # method to the ForeignKey field for accessing the generic related fields.      generic_fields = []      for model in get_models(): -        for field_name, field in filter(lambda x: isinstance(x[1], -                                                             GenericForeignKey), -                                        model.__dict__.iteritems()): +        for field_name, field in filter(lambda x: isinstance( +                x[1], GenericForeignKey), model.__dict__.iteritems()):              generic_fields.append(field)      blank_local_fields = set() @@ -53,9 +53,11 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):          if value in [None, '']:              blank_local_fields.add(field.attname) -    # Loop through all alias objects and migrate their data to the primary object. +    # Loop through all alias objects and migrate their data to the primary +    # object.      for alias_object in alias_objects: -        # Migrate all foreign key references from alias object to primary object. +        # Migrate all foreign key references from alias object to primary +        # object.          for related_object in alias_object._meta.get_all_related_objects():              # The variable name on the alias_object model.              alias_varname = related_object.get_accessor_name() @@ -66,9 +68,10 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):                  setattr(obj, obj_varname, primary_object)                  obj.save() -        # Migrate all many to many references from alias object to primary object. +        # Migrate all many to many references from alias object to primary +        # object.          related_many_objects = \ -                       alias_object._meta.get_all_related_many_to_many_objects() +            alias_object._meta.get_all_related_many_to_many_objects()          related_many_object_names = set()          for related_many_object in related_many_objects:              alias_varname = related_many_object.get_accessor_name() @@ -78,7 +81,8 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):              if alias_varname is not None:                  # standard case -                related_many_objects = getattr(alias_object, alias_varname).all() +                related_many_objects = getattr( +                    alias_object, alias_varname).all()                  related_many_object_names.add(alias_varname)              else:                  # special case, symmetrical relation, no reverse accessor @@ -88,7 +92,8 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):                  getattr(obj, obj_varname).remove(alias_object)                  getattr(obj, obj_varname).add(primary_object) -        # Migrate local many to many references from alias object to primary object. +        # Migrate local many to many references from alias object to primary +        # object.          for many_to_many_object in alias_object._meta.many_to_many:              alias_varname = many_to_many_object.get_attname()              if alias_varname in related_many_object_names or \ @@ -107,13 +112,15 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):          for field in generic_fields:              filter_kwargs = {}              filter_kwargs[field.fk_field] = alias_object._get_pk_val() -            filter_kwargs[field.ct_field] = field.get_content_type(alias_object) +            filter_kwargs[field.ct_field] = field.get_content_type( +                alias_object)              for generic_related_object in field.model.objects.filter( -                                                               **filter_kwargs): +                    **filter_kwargs):                  setattr(generic_related_object, field.name, primary_object)                  generic_related_object.save() -        # Try to fill all missing values in primary object by values of duplicates +        # Try to fill all missing values in primary object by values of +        # duplicates          filled_up = set()          for field_name in blank_local_fields:              val = getattr(alias_object, field_name)  | 
