# Originally shared as a GitHub gist by @sstoops (created April 10, 2019):
# https://gist.github.com/sstoops/5ae85c818baade40f02b2186763a2790
# fmt: black
"""Transfer an object and its reverse-dependants from one database to another.

This script is useful for "undeleting" objects. Say a user with the username robc
was mistakenly deleted in the Django admin (or by a query like obj.delete()) and all
reverse-related objects were deleted along with it. This may include all of their
photos, since Photo.photographer is a FK to User, and from there the stories, since
Story.teaser_photo is a FK to Photo. As you can imagine, that can be quite the
catastrophe.

If you can get your hands on a copy of the database from before the deletion
occurred, load that snapshot into a secondary database, configure Django to use both
your existing database and this new database (using the multiple-database
configuration in settings), then use this script to copy the objects across.
A sample database configuration is included below.

This script accepts a number of parameters to query for an object through a Django
model. Detailed usage is available with the --help flag on the command line.

Examples (arguments are split across lines for readability; enter each command
on a single line)::

    python django_object_transfer.py
        --model=auth.user
        --username=robc
        --src=archive
        --dest=default

Occasionally, you may run into circular dependency issues between multiple objects
that were deleted together. You can use the --exclude and --include arguments to
alleviate some of this pain by deferring the transfer of certain models until the
parent objects have been loaded::

    python django_object_transfer.py
        --model=auth.user
        --username=robc
        --src=archive
        --dest=default
        --exclude=comments.CommentFlag

This command will transfer everything except the comments.CommentFlag objects. Once
you have run it, and presumably any other commands needed to resolve the dependency
issue, you can use --include to transfer only the comments.CommentFlag objects::

    python django_object_transfer.py
        --model=auth.user
        --username=robc
        --src=archive
        --dest=default
        --include=comments.CommentFlag

You can also use the --pks flag to transfer several objects of the same model by
primary key::

    python django_object_transfer.py
        --model=stories.Story
        --pks 1234 1235 1236
        --src=archive
        --dest=default

Sample settings.DATABASES configuration::

    DATABASES = {
        'default': {
            'NAME': 'db-foo',
            'ENGINE': 'django.contrib.gis.db.backends.postgis',
            'USER': 'apache',
            'HOST': '',
        },
        # This DB is the snapshot from before the deletion
        'archive': {
            'NAME': 'db-bar',
            'ENGINE': 'django.contrib.gis.db.backends.postgis',
            'USER': 'apache',
            'HOST': '',
        },
    }
"""
import argparse
import sys
from collections import Counter
from operator import attrgetter
import django
# Standalone scripts must call django.setup() before using the ORM; it is
# called here, ahead of the Django/project imports below. It reads the
# configured settings -- NOTE(review): presumably DJANGO_SETTINGS_MODULE is set
# by whoever launches this script; confirm.
django.setup()
from cannonball.core.utils.log_compat import getLogger
from django.apps import apps
from django.db import transaction
from django.db.models.deletion import Collector
# Logger used for all progress output. getLogger comes from a project-local
# compat module; presumably it returns a logging.Logger-like object -- verify.
log = getLogger("django-object-transfer")
class DjangoUndeleter(Collector):
    """Restore an object graph by replaying Django's delete collector.

    ``Collector.collect()`` gathers every object that would be removed if the
    given objects were deleted, including cascading reverse relations. Pointed
    at the source (pre-deletion snapshot) database, this subclass reuses that
    walk to *restore* the objects: ``load()`` bulk inserts everything that was
    collected into a destination database instead of deleting it.
    """

    def delete(self):
        # Deliberately disabled. This collector is pointed at the archive
        # database, so if anyone is ever goofing with this class in the future,
        # an accidental call to the inherited delete() must not wipe the very
        # records we are trying to restore.
        raise NotImplementedError

    @staticmethod
    def _model_label(model):
        """Return ``"app_label.ObjectName"`` for *model*, e.g. ``"auth.User"``."""
        return "{}.{}".format(model._meta.app_label, model._meta.object_name)

    @staticmethod
    def _normalize_labels(labels):
        """Return the given model labels (or None) as a set of lower-cased strings."""
        return {label.lower() for label in labels or ()}

    @staticmethod
    def _is_selected(model_label, include, exclude):
        """Decide whether objects of *model_label* should be inserted.

        *include* and *exclude* are sets produced by ``_normalize_labels``.
        A non-empty *include* takes precedence and *exclude* is then ignored.
        """
        key = model_label.lower()
        if include:
            return key in include
        return key not in exclude

    @staticmethod
    def _bulk_insert(model, objs, dst_db, batch_size):
        """Insert *objs* into *dst_db* via ``bulk_create``; return how many."""
        # Every model has a _base_manager, so this also works for models that
        # do not define an ``objects`` manager.
        manager = model._base_manager.using(dst_db)
        return len(manager.bulk_create(objs, batch_size=batch_size))

    def load(self, dst_db, include, exclude, batch_size=5000):
        """Bulk insert every collected object into the *dst_db* database.

        Call this after ``collect()``. All inserts run inside a single
        ``transaction.atomic`` block on *dst_db*, so an error part way through
        rolls the whole transfer back.

        :param dst_db: Database alias to insert into.
        :param include: Optional iterable of ``app_label.ModelName`` labels;
            when non-empty, ONLY these models are inserted.
        :param exclude: Optional iterable of labels to skip; ignored when
            *include* is non-empty. Both are matched case-insensitively.
        :param batch_size: Maximum number of rows per ``bulk_create`` batch.
        :returns: ``(total, per_model)`` -- the number of objects inserted and
            a dict mapping each ``app_label.ObjectName`` label to its count.
        """
        include = self._normalize_labels(include)
        exclude = self._normalize_labels(exclude)

        # Sort each model's instances by primary key, then let Collector.sort()
        # order the models the way delete() would process them.
        for model, instances in list(self.data.items()):
            self.data[model] = sorted(instances, key=attrgetter("pk"))
        self.sort()

        # number of objects loaded for each model label
        load_counter = Counter()

        with transaction.atomic(using=dst_db, savepoint=False):
            # insert instances -- walking the deletion order backwards so that
            # rows which other rows reference are written first
            log.info(u"Performing inserts of objects with cascading relationships.")
            for model, instances in reversed(list(self.data.items())):
                model_label = self._model_label(model)
                if not self._is_selected(model_label, include, exclude):
                    continue
                log.info(
                    " Inserting %s %s objects", len(instances), model._meta.model_name
                )
                if model._meta.model_name == "story":
                    # Project-specific: every restored Story is forced to
                    # show_ads=True. NOTE(review): the reason is not visible
                    # here -- confirm this is still wanted.
                    for obj in instances:
                        obj.show_ads = True
                load_counter[model_label] += self._bulk_insert(
                    model, instances, dst_db, batch_size
                )
            log.info(u"Done.")

            # fast inserts
            log.info(
                "Performing fast inserts, those objects without any cascading "
                "relationships"
            )
            # The collector keeps these "fast delete" candidates as querysets;
            # one model can appear in several of them, so merge them per model
            # into sets to drop duplicate objects. Excluded models are skipped
            # before their queryset is ever evaluated.
            fast_inserts = {}
            for queryset in self.fast_deletes:
                model = queryset.model
                if not self._is_selected(self._model_label(model), include, exclude):
                    continue
                fast_inserts.setdefault(model, set()).update(queryset)
            for model, obj_set in fast_inserts.items():
                if not obj_set:
                    continue
                # Sort for a deterministic insert order, like the loop above.
                obj_list = sorted(obj_set, key=attrgetter("pk"))
                log.info(
                    " Inserting %s %s objects", len(obj_list), model._meta.model_name
                )
                load_counter[self._model_label(model)] += self._bulk_insert(
                    model, obj_list, dst_db, batch_size
                )
            log.info(u"Done.")

        return sum(load_counter.values()), dict(load_counter)
def main(
    model_label, query_filter, source_db, destination_db, include=None, exclude=None
):
    """Copy matching objects, plus everything that cascades from them, between DBs.

    :param model_label: ``app_label.ModelName`` of the root objects, resolved
        with ``apps.get_model``.
    :param query_filter: Keyword arguments for ``QuerySet.filter`` selecting the
        root objects on *source_db*, e.g. ``{"username": "robc"}``.
    :param source_db: Database alias to read from (the pre-deletion snapshot).
    :param destination_db: Database alias to insert into.
    :param include: Optional model labels; when given, ONLY these are inserted.
    :param exclude: Optional model labels to skip; ignored if *include* is given.
    """
    log.info("Gathering objects from source database: %s", source_db)
    model = apps.get_model(model_label)
    log.info(
        "Using Django model: %s.%s", model._meta.app_label, model._meta.object_name
    )
    # Log every filter term. (dict views are not indexable on Python 3, so the
    # old ``query_filter.items()[0]`` raised TypeError there.)
    log.info(
        u"Querying with arguments: %s ...",
        u", ".join(
            u"{}={}".format(key, value) for key, value in sorted(query_filter.items())
        ),
    )
    objs = list(model.objects.using(source_db).filter(**query_filter))
    log.info(u" Found objects: %s", repr(objs))
    if not objs:
        log.info(u"No objects matched the query; nothing to transfer.")
        return

    log.info(u"Collecting all cascading objects ...")
    collector = DjangoUndeleter(using=source_db)
    collector.collect(objs)
    log.info(u"Done.")

    log.info(
        u"Now attempting to reinsert objects into destination database: %s",
        destination_db,
    )
    total, per_model = collector.load(
        dst_db=destination_db, include=include, exclude=exclude
    )
    log.info(u"-" * 80)
    log.info(u"%s objects loaded", total)
    for label, count in sorted(per_model.items()):
        log.info(u"%-6s %s", count, label)
    log.info(u"-" * 80)
def parse_query_filter(extra_args):
    """Turn leftover ``--field=value`` CLI tokens into ``QuerySet.filter`` kwargs.

    Only the leading dashes are stripped and each token is split on its first
    ``=``, so values may themselves contain ``=`` or end with ``-``.

    :param extra_args: Unrecognised argv tokens, e.g. ``["--slug=foo-bar"]``.
    :returns: A dict such as ``{"slug": "foo-bar"}`` (values stay strings).
    :raises ValueError: If a token is not of the form ``--field=value``.
    """
    query_filter = {}
    for token in extra_args:
        field, sep, value = token.lstrip("-").partition("=")
        if not field or not sep:
            raise ValueError(
                "Invalid query filter argument {!r}; expected --field=value".format(
                    token
                )
            )
        query_filter[field] = value
    return query_filter


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transfer a Django model object, and every object that "
        "cascades from it, from one database to another. Any extra command "
        "arguments are converted to a query filter (eg. --slug=foo-bar)."
    )
    parser.add_argument(
        "-m", "--model", help="Django model. eg. auth.User", required=True
    )
    parser.add_argument("-s", "--src", help="The source database", required=True)
    parser.add_argument("-p", "--pks", nargs="*", help="Primary keys.")
    parser.add_argument(
        "-d", "--dest", help="The destination database", required=True
    )
    parser.add_argument(
        "-i",
        "--include",
        nargs="*",
        help="Include ONLY these models in migration (eg. auth.User).",
    )
    parser.add_argument(
        "-e",
        "--exclude",
        nargs="*",
        help="Exclude these models from migration (eg. auth.User). Ignored if "
        "--include given.",
    )
    args, unknown = parser.parse_known_args()

    if args.pks:
        # --pks takes precedence; extra --field=value arguments are ignored.
        query_filter = {"pk__in": args.pks}
    elif unknown:
        try:
            query_filter = parse_query_filter(unknown)
        except ValueError as exc:
            # sys.exit(message) prints to stderr and exits with status 1.
            sys.exit(str(exc))
    else:
        sys.exit(
            u"You must pass --pks or additional query filter arguments, "
            u"eg. --slug=foo-bar or --id=5"
        )

    main(
        model_label=args.model,
        query_filter=query_filter,
        source_db=args.src,
        destination_db=args.dest,
        include=args.include,
        exclude=args.exclude,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment