Skip to content

Instantly share code, notes, and snippets.

@dcramer
Last active December 18, 2015 15:09
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dcramer/5802587 to your computer and use it in GitHub Desktop.
Save dcramer/5802587 to your computer and use it in GitHub Desktop.
class EverythingCollector(Collector):
    """
    More or less identical to the default Django collector, except that we
    always return relations (even when they shouldn't matter).

    Directly related objects are added to ``self.data`` via ``self.add``
    without consulting any ``on_delete`` handler, so callers such as
    ``merge_into`` can inspect and re-point every row that references the
    collected objects.
    """

    def collect(self, objs, source=None, nullable=False, collect_related=True,
                source_attr=None, reverse_dependency=False):
        """
        Add ``objs`` and the objects that reference them to ``self.data``.

        ``source``, ``nullable``, ``source_attr`` and ``reverse_dependency``
        are accepted so recursive calls keep the ``Collector.collect``
        signature, but they are not forwarded to ``self.add``.

        Objects found through ``get_all_related_objects`` are added but not
        recursed into. Concrete parent rows and GenericRelation objects are
        collected through recursive calls.
        """
        new_objs = self.add(objs)
        if not new_objs:
            return

        model = new_objs[0].__class__

        # Recursively collect concrete model's parent models, but not their
        # related objects. These will be found by meta.get_all_related_objects()
        concrete_model = model._meta.concrete_model
        # ``parents`` maps parent model -> parent-link field (possibly None).
        # We need the field itself, so iterate the values; iterating the
        # (key, value) items would hand us tuples with no ``.name``/``.rel``.
        for ptr in concrete_model._meta.parents.values():
            if ptr:
                # FIXME: This seems to be buggy and execute a query for each
                # parent object fetch. We have the parent data in the obj,
                # but we don't have a nice way to turn that data into parent
                # object instance.
                parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
                self.collect(parent_objs, source=model,
                             source_attr=ptr.rel.related_name,
                             collect_related=False,
                             reverse_dependency=True)

        if collect_related:
            for related in model._meta.get_all_related_objects(
                    include_hidden=True, include_proxy_eq=True):
                sub_objs = self.related_objects(related, new_objs)
                # Add unconditionally: no on_delete / fast-delete filtering.
                self.add(sub_objs)

            # TODO This entire block is only needed as a special case to
            # support cascade-deletes for GenericRelation. It should be
            # removed/fixed when the ORM gains a proper abstraction for virtual
            # or composite fields, and GFKs are reworked to fit into that.
            for relation in model._meta.many_to_many:
                if not relation.rel.through:
                    sub_objs = relation.bulk_related_objects(new_objs, self.using)
                    self.collect(sub_objs,
                                 source=model,
                                 source_attr=relation.rel.related_name,
                                 nullable=True)
def merge_into(self, other, callback=lambda x: x, using='default'):
    """
    Collects objects related to ``self`` and updates their foreign keys to
    point to ``other``.

    If ``callback`` is specified, it is executed on each collected chunk (the
    related objects of one model) before any changes are made, and should
    return a modified iterable of results that still need updating.

    NOTE: Duplicates (unique constraints) which exist and are bound to
    ``other`` are preserved, and relations on ``self`` are discarded.

    :param self: the model instance being merged away.
    :param other: the instance that related rows should point to instead.
    :param callback: filter applied to each per-model chunk of related objects.
    :param using: database alias used for collection, queries and savepoints.
    :raises TypeError: if a related model has no ForeignKey to ``type(self)``.
    :raises ValueError: if a collected row does not reference ``self``.
    """
    # TODO: proper support for database routing
    s_model = type(self)

    # Find all the objects that reference ``self``.
    collector = EverythingCollector(using=using)
    collector.collect([self])

    # Resolve the keys and run ``callback`` for every model up front, so an
    # unsupported relation raises before any row has been modified.
    plan = []
    for model, objects in collector.data.items():
        # find all potential keys which match our type
        fields = set(
            f.name for f in model._meta.fields
            if isinstance(f, ForeignKey) and f.rel.to == s_model
        )
        if not fields:
            # the collector pulls in the self reference, so if it's our model
            # we actually assume it's probably not related to itself, and it's
            # perfectly ok
            if model == s_model:
                continue
            raise TypeError('Unable to determine related keys on %r' % model)
        plan.append((model, fields, callback(objects)))

    for model, fields, objects in plan:
        # Models flagged ``auto_created`` (e.g. implicit M2M through tables)
        # are updated without sending any signals.
        send_signals = not model._meta.auto_created
        for obj in objects:
            # find fields which need changing
            update_kwargs = dict(
                (f_name, other) for f_name in fields
                if getattr(obj, f_name) == self
            )
            if not update_kwargs:
                # as before, if we're referencing ourself, this is ok
                if obj == self:
                    continue
                raise ValueError('Mismatched row present in related results')

            signal_kwargs = {
                'sender': model,
                'instance': obj,
                'using': using,
                'migrated': True,
            }
            # Receivers see the move as a delete followed by a creation,
            # flagged with ``migrated=True``.
            if send_signals:
                pre_delete.send(**signal_kwargs)
                post_delete.send(**signal_kwargs)

            for k, v in update_kwargs.items():
                setattr(obj, k, v)

            if send_signals:
                pre_save.send(created=True, **signal_kwargs)

            # Run the queries on the same alias as the savepoint, otherwise
            # the rollback would not cover the failed update.
            queryset = model.objects.using(using).filter(pk=obj.pk)
            sid = transaction.savepoint(using=using)
            try:
                queryset.update(**update_kwargs)
            except IntegrityError:
                # duplicate key exists, destroy the relations
                transaction.savepoint_rollback(sid, using=using)
                queryset.delete()
            else:
                transaction.savepoint_commit(sid, using=using)
                # Only announce the "creation" when the row actually survived.
                if send_signals:
                    post_save.send(created=True, **signal_kwargs)
class MergeIntoTest(TestCase):
    """Integration test for ``merge_into`` across FK and M2M relations."""

    def test_all_the_things(self):
        user_1 = User.objects.create(username='original')
        user_2 = User.objects.create(username='new')
        team_1 = Team.objects.create(owner=user_1)
        team_2 = Team.objects.create(owner=user_2)
        project_1 = Project.objects.create(owner=user_1, team=team_1)
        project_2 = Project.objects.create(owner=user_2, team=team_2)
        # Both users are members of the same group, so re-pointing user_1's
        # membership presumably collides with user_2's existing row in the
        # through table and exercises the duplicate-discard path.
        ag = AccessGroup.objects.create(team=team_2)
        ag.members.add(user_1)
        ag.members.add(user_2)

        merge_into(user_1, user_2)

        # Use TestCase assertions: bare ``assert`` is stripped under -O.
        self.assertEqual(Team.objects.get(id=team_1.id).owner, user_2)
        self.assertEqual(Team.objects.get(id=team_2.id).owner, user_2)
        self.assertEqual(Project.objects.get(id=project_1.id).owner, user_2)
        self.assertEqual(Project.objects.get(id=project_2.id).owner, user_2)
        self.assertEqual(list(ag.members.all()), [user_2])
        # make sure we didn't remove the instance
        self.assertTrue(User.objects.filter(id=user_1.id).exists())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment