Last active
July 8, 2021 10:25
-
-
Save racitup/a584ed2dc4583d61750f08e419ea26c8 to your computer and use it in GitHub Desktop.
Django db integrity check management command
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import time | |
class ProgressBase(object):
    """
    An abstract class that helps to put text on the screen and erase it again.

    Subclasses call ``_show`` to replace whatever was displayed last with new
    text. All output goes to ``sys.stderr``.
    """

    def __init__(self):
        # The text currently on screen; its length tells _show how many
        # characters have to be erased before drawing the next text.
        self._str = ''

    def _show(self, text):
        """
        Overwrite the previously shown text with ``text``.

        The cursor is moved back with backspaces over the old text, and the
        new text is padded with spaces to at least the old length so that no
        stale characters remain visible.
        """
        sys.stderr.write('\b' * len(self._str))
        self._str = text.ljust(len(self._str))
        sys.stderr.write(self._str)
        # The text contains no newline, so a line-buffered stderr would hold
        # it back until the final newline; flush to make the update visible.
        sys.stderr.flush()
class with_spinner(ProgressBase):
    """
    Spinner drawn on stderr while looping over something of unknown length.

    Use it as a wrapper around any iterable, for example:

        for line in with_spinner(lines, action = 'Processing lines...')
    """
    # Animation frames, shown one after another.
    chars = '|/-\\'

    def __init__(self, iterable = None, action = None, done = 'done'):
        """
        Remember the iterable and print the optional ``action`` label.

        ``done`` is the text that replaces the spinner once iteration ends.
        """
        super(with_spinner, self).__init__()
        self.iterable = iterable
        self.done = done
        self.frame = 0
        self.last = time.time()
        if action:
            sys.stderr.write(action + ' ')

    def update(self):
        """Advance the spinner one frame, at most once per half second."""
        current = time.time()
        if not (self.last + 0.5 < current):
            # Too soon since the last redraw; leave the screen alone.
            return
        self.last = current
        glyphs = with_spinner.chars
        self.frame = (self.frame + 1) % len(glyphs)
        self._show(glyphs[self.frame])

    def stop(self):
        """Replace the spinner with the ``done`` text and end the line."""
        final_text = self.done if self.done else ''
        self._show(final_text)
        sys.stderr.write('\n')

    def __iter__(self):
        """Yield every item of the wrapped iterable, animating in between."""
        for element in self.iterable:
            yield element
            self.update()
        self.stop()
class with_progress_meter(ProgressBase):
    """
    A progress meter for long loops of known length, written on stderr.

    Wrap this around a list-like object, for example:

        for line in with_progress_meter(lines, action = 'Processing lines...')

    The meter shows a percentage and an estimated time remaining.
    """

    def __init__(self, iterable=None, total=None, action=None, done='done'):
        """
        Remember the iterable and print the optional ``action`` label.

        ``total`` is the expected number of items; when omitted it is taken
        from ``len(iterable)``. ``done`` is the text that replaces the meter
        once iteration ends.
        """
        super(with_progress_meter, self).__init__()
        self.iterable = iterable
        if total is None:
            total = len(self.iterable)
        self.total = total
        self.start_time = time.time()
        self.last = self.start_time
        self.at = 0
        self.done = done
        if action:
            sys.stderr.write(action + ' ')

    def update(self, at):
        """Record that ``at`` items are finished; redraw at most twice a second."""
        self.at = at
        now = time.time()
        if self.last + 0.5 < now:
            self.last = now
            self._show(self._progress())

    def stop(self):
        """Replace the meter with the ``done`` text and end the line."""
        self._show(self.done or '')
        sys.stderr.write('\n')

    def __iter__(self):
        """Yield every item of the wrapped iterable, updating the meter."""
        at = 0
        for item in self.iterable:
            yield item
            at += 1
            self.update(at)
        self.stop()

    def _progress(self):
        """Return the meter text: percentage plus ETA once an item is done."""
        # A total of zero means there is nothing to do, i.e. 100% complete.
        percent = int(self.at * 100 / self.total) if self.total else 100
        text = '%3d%%' % percent
        if self.at > 0:
            spent = time.time() - self.start_time
            # More items than ``total`` may be yielded (for example when the
            # count was taken before iterating); never show a negative ETA.
            remaining = max(0.0, (self.total - self.at) * spent / self.at)
            minutes, seconds = divmod(int(remaining), 60)
            millis = int(remaining * 1000) % 1000
            text += ' (ETA: %d:%02d.%03d)' % (minutes, seconds, millis)
        return text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.core.management.base import BaseCommand | |
from django.core.exceptions import ObjectDoesNotExist | |
from django.db import models | |
from django.apps import apps | |
from ._progress import with_progress_meter | |
def model_name(model):
    """Return the ``app_label.ObjectName`` label of a model class."""
    opts = model._meta
    return '{0}.{1}'.format(opts.app_label, opts.object_name)
class Command(BaseCommand):
    """
    Management command that looks for dangling foreign keys in every model.

    Supports django 1.8 up.

    Every instance of every installed model (except proxy models and models
    excluded on the command line) is loaded, and each of its foreign keys is
    resolved. Violations are written to stderr and the process exits with
    status 1 if any were found.
    """
    # NOTE: violations are actually written to stderr, not stdout.
    help = 'Checks constraints in the database and reports violations on stdout'

    def add_arguments(self, parser):
        """Register the repeatable ``-e/--exclude app.model`` option."""
        parser.add_argument(
            '-e',
            '--exclude',
            action='append',
            type=str,
            dest='exclude',
            help="a model to exclude in the format 'app.model'"
        )

    def handle(self, *args, **options):
        """Check all models, print a summary, and exit with 1 on violations."""
        # The option is None when -e was never given.
        exclude = options.get('exclude', None) or []
        failed_instance_count = 0
        failed_model_count = 0
        for modelclass in apps.get_models():
            name = model_name(modelclass)
            if name in exclude:
                self.stdout.write('Skipped model %s' % name)
                continue
            fail_count = self.check_model(modelclass)
            if fail_count > 0:
                failed_model_count += 1
                failed_instance_count += fail_count
        if failed_model_count:
            self.stderr.write('Detected %d errors in %d models' % (failed_instance_count, failed_model_count))
            # The exit() builtin is injected by the site module and is not
            # guaranteed to exist; raising SystemExit has the same effect.
            raise SystemExit(1)
        self.stdout.write('No errors found')

    def check_model(self, model):
        """
        Run all checks on every instance of ``model``.

        Returns the number of failed checks. Proxy models are skipped with a
        warning and count as zero failures.
        """
        meta = model._meta
        if meta.proxy:
            self.stderr.write('WARNING: proxy models not currently supported; %s ignored' % model_name(model))
            return 0

        # Define all the checks we can do; they return True if they are ok,
        # False if not (and print a message to stderr)
        def check_foreign_key(model, field):
            foreign_model = field.related_model

            def check_instance(instance):
                try:
                    # name: name of the attribute containing the model instance (e.g. 'user')
                    # attname: name of the attribute containing the id (e.g. 'user_id')
                    getattr(instance, field.name)
                except ObjectDoesNotExist:
                    self.stderr.write(
                        '%s with pk %s refers via field %s to nonexistent %s with pk %s' % (
                            model_name(model),
                            str(instance.pk),
                            field.name,
                            model_name(foreign_model),
                            getattr(instance, field.attname),
                        )
                    )
                    return False
                return True
            return check_instance

        # Options.virtual_fields was renamed to private_fields in Django 1.10
        # and removed in 2.0; use whichever this Django version provides.
        extra_fields = getattr(meta, 'private_fields', None)
        if extra_fields is None:
            extra_fields = getattr(meta, 'virtual_fields', [])
        candidate_fields = (
            list(meta.local_fields)
            + list(meta.local_many_to_many)
            + list(extra_fields)
        )

        # Make a list of checks to run on each model instance
        checks = [
            check_foreign_key(model, field)
            for field in candidate_fields
            if isinstance(field, models.ForeignKey)
        ]

        # Run all checks
        fail_count = 0
        if checks:
            # A model is not guaranteed to have a manager called "objects";
            # _default_manager is always present.
            manager = model._default_manager
            for instance in with_progress_meter(manager.all(), manager.count(), 'Checking model %s ...' % model_name(model)):
                for check in checks:
                    if not check(instance):
                        fail_count += 1
        return fail_count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Originally from Stack Overflow, but updated for Django 1.8+.