Skip to content

Instantly share code, notes, and snippets.

@myaser
Last active April 28, 2022 23:31
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save myaser/7689869 to your computer and use it in GitHub Desktop.
Save myaser/7689869 to your computer and use it in GitHub Desktop.

django-test-abstract-models

A `TestCase` subclass that helps you test your abstract Django models.

Usage:

1. Create concrete subclasses of your abstract Django models.

2. Write your test case like this:

class MyTestCase(AbstractModelTestCase):
    models = [MyAbstractModelSubClass, ...]
    # your tests go here ...

3. If you don't set the `models` attribute (or leave it empty), the test case looks in the current app's `myapp.tests.models` module and uses every name there that starts with `Test`.

from django.db import connection
from django.test import TestCase
from django.core.management.color import no_style
from importlib import import_module
def sync_models(model_list):
    '''
    Create the database tables for the given models.

    Used for testing abstract models: pass concrete subclasses of the abstract
    models and their tables are created with raw CREATE TABLE statements
    (including any deferred foreign-key constraints between them).

    Models whose table already exists are skipped.

    NOTE: ``connection.creation.sql_create_model`` and
    ``sql_for_pending_references`` are legacy Django APIs that were removed in
    Django 1.9; on newer versions the equivalent is
    ``connection.schema_editor().create_model(model)``.

    :param model_list: iterable of model classes whose tables should exist.
    '''
    style = no_style()
    tables = connection.introspection.table_names()
    seen_models = connection.introspection.installed_models(tables)
    # Maps a referenced model -> FK references waiting for its table.
    pending_references = {}
    cursor = connection.cursor()
    try:
        for model in model_list:
            table_name = connection.introspection.table_name_converter(
                model._meta.db_table)
            if table_name in tables:
                # Table already exists: don't issue CREATE TABLE again (it
                # would error), but record the model as known so later models
                # can reference it with an inline foreign key.
                seen_models.add(model)
                continue
            sql, references = connection.creation.sql_create_model(
                model, style, seen_models)
            seen_models.add(model)
            for refto, refs in references.items():
                pending_references.setdefault(refto, []).extend(refs)
                if refto in seen_models:
                    # The target table exists now, so its pending FK
                    # constraints can be emitted immediately.
                    sql.extend(connection.creation.sql_for_pending_references(
                        refto, style, pending_references))
            # Resolve constraints from earlier models that point at this one.
            sql.extend(connection.creation.sql_for_pending_references(
                model, style, pending_references))
            for statement in sql:
                cursor.execute(statement)
            tables.append(table_name)
    finally:
        cursor.close()
def get_test_models(package):
    '''
    Discover test models in ``<package>.tests.models``.

    :param package: dotted name of the app package, e.g. ``'myapp'``.
    :returns: the attributes of the ``<package>.tests.models`` module whose
        names start with ``'Test'``, in ``dir()`` (alphabetical) order, or an
        empty list if that module cannot be imported.
    :raises: any non-ImportError raised while importing the module (e.g. a
        SyntaxError in it), instead of silently hiding it.
    '''
    try:
        models_module = import_module('.tests.models', package)
    except ImportError:
        # No tests/models module for this app: nothing to discover.
        # NOTE(review): an ImportError raised *inside* that module is also
        # treated as "no models" -- consider surfacing it if that bites.
        return []
    return [getattr(models_module, name) for name in dir(models_module)
            if name.startswith('Test')]
class AbstractModelTestCase(TestCase):
    '''
    TestCase that creates database tables for test models before each test.

    Set the ``models`` class attribute to the list of concrete model classes
    (typically subclasses of the abstract models under test). If ``models`` is
    not defined or is empty, models are discovered with get_test_models() from
    the ``<app>.tests.models`` module, where ``<app>`` is the part of this test
    module's dotted path before ``.tests``.
    '''

    def setUp(self):
        super(AbstractModelTestCase, self).setUp()
        # ``models`` is optional: fall back to discovery when the subclass
        # does not define it, rather than failing with AttributeError.
        if not getattr(self, 'models', None):
            # e.g. 'myapp.tests.test_foo' -> 'myapp'
            self.models = get_test_models(
                self.__module__.partition('.tests')[0])
        sync_models(self.models)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment