Skip to content

Instantly share code, notes, and snippets.

@btimby
Last active August 24, 2021 14:35
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save btimby/5811298 to your computer and use it in GitHub Desktop.
Save btimby/5811298 to your computer and use it in GitHub Desktop.
Use a Django database router, a TestCase mixin, and thread-local storage to let unit tests switch databases.
# Detect whether this process runs tests (or a test/lint tool): true when
# any of these command names appears as a whole element of sys.argv.
TESTING = not set(sys.argv).isdisjoint((
    'test', 'csslint', 'jenkins', 'jslint',
    'jtest', 'lettuce', 'pep8', 'pyflakes',
    'pylint', 'sloccount',
))

if TESTING:
    # Keep the configured default database reachable under the 'mysql'
    # alias and put an in-memory SQLite database in its place as 'default'.
    DATABASES.update(
        mysql=DATABASES['default'],
        default={
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': ':memory:',
        },
    )

    # Install the test router so individual test cases can choose which
    # database their queries go to.
    DATABASE_ROUTERS = ('myapp.tests.TestUsingDbRouter', )
import threading
from django.db import DEFAULT_DB_ALIAS
from django.test.testcases import TestCase
# Per-thread storage for the database alias chosen by the running test.
# The helpers below keep the selection in its 'test_db_name' attribute, so
# a selection made in one thread is not seen by other threads.
_LOCALS = threading.local()
def set_test_db(db_name):
    """Select ``db_name`` as the alias that get_test_db() returns in this thread.

    The selection lasts until del_test_db() is called or another call
    overwrites it.
    """
    _LOCALS.test_db_name = db_name
def get_test_db():
    """Return the database alias selected in this thread.

    Falls back to Django's DEFAULT_DB_ALIAS when nothing has been selected.
    An explicit ``set_test_db(None)`` is returned as ``None``.
    """
    try:
        return _LOCALS.test_db_name
    except AttributeError:
        return DEFAULT_DB_ALIAS
def del_test_db():
    """Forget this thread's selected alias so get_test_db() returns the default.

    Safe to call when no alias is currently selected.
    """
    if hasattr(_LOCALS, 'test_db_name'):
        del _LOCALS.test_db_name
class TestUsingDbRouter(object):
    """Database router that sends every query to the test-selected alias.

    Both reads and writes go to whatever get_test_db() currently returns
    for the calling thread. The model and hints are ignored.
    """

    def db_for_read(self, model, **hints):
        # Routing depends only on the running test's selection.
        return get_test_db()

    def db_for_write(self, model, **hints):
        # Writes follow the same rule as reads.
        return get_test_db()
class UsingDbMixin(object):
    """Mixin letting a TestCase choose which database alias the ORM uses.

    Set ``using_db`` on the subclass to a database alias. For the duration
    of each test, the selection is stored per thread via set_test_db(), and
    the test router (TestUsingDbRouter) routes reads and writes to it.
    List this mixin before TestCase in the bases so that its setUp() and
    tearDown() run and chain to TestCase through super().
    """

    # Opt-in to using every configured database in tests on Django < 2.2.
    # Django 3.1 removed this attribute; it is kept for older versions.
    multi_db = True
    # Django 2.2+ equivalent of ``multi_db = True``. Without it, newer
    # Django rejects queries against aliases other than 'default' (such as
    # 'mysql') inside a test. Older Django ignores this attribute.
    databases = '__all__'
    # Alias to route to. None makes the router return None, meaning "no
    # preference", so Django applies its usual fallback.
    using_db = None

    def setUp(self, *args, **kwargs):
        super(UsingDbMixin, self).setUp(*args, **kwargs)
        set_test_db(self.using_db)
        # unittest skips tearDown() when setUp() fails, e.g. when a
        # subclass's setUp() raises after calling super(). Cleanups always
        # run, so registering one here keeps the selection from leaking
        # into later tests on this thread.
        self.addCleanup(del_test_db)

    def tearDown(self, *args, **kwargs):
        # Redundant with the cleanup above (del_test_db tolerates repeat
        # calls), but kept so subclasses see the same tearDown behavior.
        del_test_db()
        super(UsingDbMixin, self).tearDown(*args, **kwargs)
class MySQLTestCase(UsingDbMixin, TestCase):
    """Unit tests that run against the 'mysql' database alias.

    ``using_db`` is read by UsingDbMixin.setUp(), which selects this alias
    for the test router at the start of every test.
    """

    using_db = 'mysql'

    def test_mysql_something(self):
        # TODO: test something specific to MySQL
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment