Skip to content

Instantly share code, notes, and snippets.

@TobeTek
Created March 2, 2023 22:48
Show Gist options
  • Save TobeTek/e6214cebcf138f1127a1a64a4d1fa494 to your computer and use it in GitHub Desktop.
Save TobeTek/e6214cebcf138f1127a1a64a4d1fa494 to your computer and use it in GitHub Desktop.
A cleaner approach to mocking unmanaged models in Django tests
"""
A cleaner approach to temporarily creating unmanaged model db tables for tests
"""
import functools
from unittest import TestCase

from django.db import connections, models
class create_unmanaged_model_tables:
"""
Create db tables for unmanaged models for tests
Adapted from: https://stackoverflow.com/a/49800437
Examples:
with create_unmanaged_model_tables(UnmanagedModel):
...
@create_unmanaged_model_tables(UnmanagedModel, FooModel)
def test_generate_data():
...
@create_unmanaged_model_tables(UnmanagedModel, FooModel)
def MyTestCase(unittest.TestCase):
...
"""
def __init__(self, unmanaged_models: list[ModelBase], db_alias: str = "default"):
"""
:param str db_alias: Name of the database to connect to, defaults to "default"
"""
self.unmanaged_models = unmanaged_models
self.db_alias = db_alias
self.connection = connections[db_alias]
def __call__(self, obj):
if issubclass(obj, TestCase):
return self.decorate_class(obj)
return self.decorate_callable(obj)
def __enter__(self):
self.start()
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
def start(self):
with self.connection.schema_editor() as schema_editor:
for model in self.unmanaged_models:
schema_editor.create_model(model)
if (
model._meta.db_table
not in self.connection.introspection.table_names()
):
raise ValueError(
"Table `{table_name}` is missing in test database.".format(
table_name=model._meta.db_table
)
)
def stop(self):
with self.connection.schema_editor() as schema_editor:
for model in self.unmanaged_models:
schema_editor.delete_model(model)
def copy(self):
return self.__class__(
unmanaged_models=self.unmanaged_models, db_alias=self.db_alias
)
def decorate_class(self, klass):
# Modify setUpClass and tearDownClass
orig_setUpClass = klass.setUpClass
orig_tearDownClass = klass.tearDownClass
# noinspection PyDecorator
@classmethod
def setUpClass(cls):
self.start()
if orig_setUpClass is not None:
orig_setUpClass()
self.stop()
# noinspection PyDecorator
@classmethod
def tearDownClass(cls):
self.start()
if orig_tearDownClass is not None:
orig_tearDownClass()
self.stop()
klass.setUpClass = setUpClass
klass.tearDownClass = tearDownClass
orig_setUp = klass.setUp
orig_tearDown = klass.tearDown
def setUp(*args, **kwargs):
self.start()
if orig_setUp is not None:
orig_setUp(*args, **kwargs)
def tearDown(*args, **kwargs):
if orig_tearDown is not None:
orig_tearDown(*args, **kwargs)
self.stop()
klass.setUp = setUp
klass.tearDown = tearDown
return klass
def decorate_callable(self, callable_obj):
@functools.wraps(callable_obj)
def wrapper(*args, **kwargs):
with self.copy():
return callable_obj(*args, **kwargs)
return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment