Skip to content

Instantly share code, notes, and snippets.

@racitup
Last active July 8, 2021 10:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save racitup/a584ed2dc4583d61750f08e419ea26c8 to your computer and use it in GitHub Desktop.
Save racitup/a584ed2dc4583d61750f08e419ea26c8 to your computer and use it in GitHub Desktop.
Django db integrity check management command
import sys
import time
class ProgressBase(object):
    """
    An abstract class that helps to put text on the screen and erase it again.

    Subclasses call ``_show`` to replace the text previously written to
    stderr (by this object) with new text.
    """

    def __init__(self):
        # Text currently displayed; its length is how many characters the
        # next ``_show`` must back over with backspaces.
        self._str = ''

    def _show(self, text):
        """Erase the previously shown text on stderr and write *text* instead."""
        sys.stderr.write('\b' * len(self._str))
        # Pad to the old width so any tail of a longer previous text is
        # overwritten with spaces.
        self._str = text.ljust(len(self._str))
        sys.stderr.write(self._str)
        # stderr may be line- or block-buffered and this output never contains
        # a newline, so flush explicitly or nothing appears until the loop ends.
        sys.stderr.flush()
class with_spinner(ProgressBase):
    """
    Spinner drawn on stderr while looping over an iterable of unknown length.

    Usage example:
    for line in with_spinner(lines, action = 'Processing lines...')

    :param iterable: the items to iterate over (yielded unchanged).
    :param action: optional label written once, followed by a space.
    :param done: text that replaces the spinner when iteration finishes.
    """
    chars = '|/-\\'

    def __init__(self, iterable=None, action=None, done='done'):
        super(with_spinner, self).__init__()
        self.done = done
        self.iterable = iterable
        self.frame = 0
        self.last = time.time()
        if action:
            sys.stderr.write(action + ' ')

    def update(self):
        """Advance to the next spinner glyph, at most once every half second."""
        current = time.time()
        if current <= self.last + 0.5:
            return
        glyphs = with_spinner.chars
        self.last = current
        self.frame = (self.frame + 1) % len(glyphs)
        self._show(glyphs[self.frame])

    def stop(self):
        """Replace the spinner with the ``done`` text and finish the line."""
        final = self.done if self.done else ''
        self._show(final)
        sys.stderr.write('\n')

    def __iter__(self):
        # Yield each element, then give the spinner a chance to advance;
        # once the underlying iterable is exhausted, print the done text.
        for element in self.iterable:
            yield element
            self.update()
        self.stop()
class with_progress_meter(ProgressBase):
    """
    A progress meter for long loops of known length, written on stderr.

    Wrap this around a list-like object, for example:
    for line in with_progress_meter(lines, action = 'Processing lines...')

    :param iterable: the items to iterate over (yielded unchanged).
    :param total: expected number of items; defaults to ``len(iterable)``.
    :param action: optional label written once, followed by a space.
    :param done: text that replaces the meter when iteration finishes.
    """

    def __init__(self, iterable=None, total=None, action=None, done='done'):
        super(with_progress_meter, self).__init__()
        self.iterable = iterable
        if total is None:
            # Requires a sized iterable; pass ``total`` explicitly otherwise.
            total = len(self.iterable)
        self.total = total
        self.start_time = time.time()
        self.last = self.start_time
        self.at = 0
        self.done = done
        if action:
            sys.stderr.write(action + ' ')

    def update(self, at):
        """Record that ``at`` items are done; redraw at most every 0.5 s."""
        self.at = at
        now = time.time()
        if self.last + 0.5 < now:
            self.last = now
            self._show(self._progress())

    def stop(self):
        """Replace the meter with the ``done`` text and end the line."""
        self._show(self.done or '')
        sys.stderr.write('\n')

    def __iter__(self):
        at = 0
        for item in self.iterable:
            yield item
            at += 1
            self.update(at)
        self.stop()

    def _progress(self):
        """Return the meter text: a percentage and, once started, an ETA."""
        text = '%3d%%' % int(self.at * 100 / self.total if self.total else 100)
        if self.at > 0:
            spent = time.time() - self.start_time
            # Clamp at zero: when more items arrive than ``total`` predicted
            # (e.g. rows inserted after a count), the estimate goes negative
            # and would otherwise render as a nonsensical wrapped time.
            remaining = max(0.0, (self.total - self.at) * spent / self.at)
            minutes, seconds = divmod(int(remaining), 60)
            text += ' (ETA: %d:%02d.%03d)' % (
                minutes,
                seconds,
                int(remaining * 1000) % 1000)
        return text
from django.core.management.base import BaseCommand
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.apps import apps
from ._progress import with_progress_meter
def model_name(model):
    """Return the ``'<app_label>.<ObjectName>'`` label of a Django model class."""
    opts = model._meta
    return '{0}.{1}'.format(opts.app_label, opts.object_name)
class Command(BaseCommand):
    """Check that every foreign key in the database points at an existing row.

    Supports Django 1.8 and later.
    """
    help = 'Checks foreign-key constraints in the database and reports violations on stderr'

    def add_arguments(self, parser):
        parser.add_argument(
            '-e',
            '--exclude',
            action='append',
            type=str,
            dest='exclude',
            help="a model to exclude in the format 'app.model'"
        )

    def handle(self, *args, **options):
        """Check all installed models; exit with status 1 if any check fails."""
        exclude = options.get('exclude', None) or []
        failed_instance_count = 0
        failed_model_count = 0
        for modelclass in apps.get_models():
            name = model_name(modelclass)
            if name in exclude:
                self.stdout.write('Skipped model %s' % name)
                continue
            fail_count = self.check_model(modelclass)
            if fail_count > 0:
                failed_model_count += 1
                failed_instance_count += fail_count
        if failed_model_count:
            self.stderr.write('Detected %d errors in %d models' % (failed_instance_count, failed_model_count))
            # Raise SystemExit directly: the bare ``exit`` builtin is added by
            # the ``site`` module for interactive use and is not guaranteed.
            raise SystemExit(1)
        self.stdout.write('No errors found')

    def check_model(self, model):
        """Run every applicable check on every row of *model*.

        Returns the number of failed checks (0 for proxy models, which are
        skipped with a warning on stderr).
        """
        meta = model._meta
        if meta.proxy:
            self.stderr.write('WARNING: proxy models not currently supported; %s ignored' % model_name(model))
            return 0

        # Each check returns True if the instance is ok, or False after
        # writing a message to stderr.
        checks = [
            self._check_foreign_key(model, field)
            for field in self._candidate_fields(meta)
            if isinstance(field, models.ForeignKey)
        ]
        if not checks:
            return 0

        # ``_base_manager`` exists on every model (``objects`` may not) and
        # ignores any filtering a custom default manager applies, so every
        # row in the table is checked.
        manager = model._base_manager
        # ``iterator()`` streams rows instead of caching the whole table;
        # the total is passed explicitly, so the meter never needs len().
        rows = manager.all().iterator()
        fail_count = 0
        for instance in with_progress_meter(rows, manager.count(), 'Checking model %s ...' % model_name(model)):
            for check in checks:
                if not check(instance):
                    fail_count += 1
        return fail_count

    @staticmethod
    def _candidate_fields(meta):
        """Return the model's local, many-to-many and private/virtual fields."""
        # ``virtual_fields`` was renamed ``private_fields`` in Django 1.10 and
        # removed in 2.0; try the new name first so both eras work.
        private = getattr(meta, 'private_fields', None)
        if private is None:
            private = getattr(meta, 'virtual_fields', [])
        return list(meta.local_fields) + list(meta.local_many_to_many) + list(private)

    def _check_foreign_key(self, model, field):
        """Build a check that an instance's FK *field* refers to an existing row."""
        foreign_model = field.related_model

        def check_instance(instance):
            try:
                # name: name of the attribute containing the model instance (e.g. 'user')
                # attname: name of the attribute containing the id (e.g. 'user_id')
                getattr(instance, field.name)
            except ObjectDoesNotExist:
                self.stderr.write('%s with pk %s refers via field %s to nonexistent %s with pk %s' % (
                    model_name(model), str(instance.pk), field.name,
                    model_name(foreign_model), getattr(instance, field.attname)))
                return False
            return True

        return check_instance
@racitup
Copy link
Author

racitup commented Sep 28, 2017

Originally from Stack Overflow, but updated for Django 1.8+.

@racitup
Copy link
Author

racitup commented Sep 28, 2017

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment