Skip to content

Instantly share code, notes, and snippets.

@elfgzp
Created January 17, 2019 07:34
Show Gist options
  • Save elfgzp/4b74a4b0860f3cecba17c56979e8e116 to your computer and use it in GitHub Desktop.
Save elfgzp/4b74a4b0860f3cecba17c56979e8e116 to your computer and use it in GitHub Desktop.
class TestCase(TransactionTestCase):
    """
    Similar to TransactionTestCase, but use `transaction.atomic()` to achieve
    test isolation.

    In most situations, TestCase should be preferred to TransactionTestCase as
    it allows faster execution. However, there are some situations where using
    TransactionTestCase might be necessary (e.g. testing some transactional
    behavior).

    On database backends with no transaction support, TestCase behaves as
    TransactionTestCase.

    Transaction support is decided per test class: only the connections whose
    alias is returned by `_databases_names()` are inspected.
    """

    @classmethod
    def _databases_support_transactions(cls):
        """Return True if every database this class acts on supports transactions."""
        # Resolve the alias list once instead of once per connection.
        aliases = frozenset(cls._databases_names())
        return all(
            conn.features.supports_transactions
            for conn in connections.all()
            if conn.alias in aliases
        )

    @classmethod
    def _databases_names(cls, include_mirrors=True):
        """
        Return the list of database aliases this test case acts on.

        With `include_mirrors=False`, aliases whose TEST['MIRROR'] setting is
        truthy are left out (only relevant when `multi_db` is set).
        """
        # If the test case has a multi_db=True flag, act on all databases,
        # including mirrors or not. Otherwise, just on the default DB.
        if not cls.multi_db:
            return [DEFAULT_DB_ALIAS]
        return [
            alias for alias in connections
            if include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
        ]

    @classmethod
    def _enter_atomics(cls):
        """
        Open atomic blocks for multiple databases.

        Return a dict mapping each database alias to its entered atomic block.
        """
        atomics = {}
        for db_name in cls._databases_names():
            atomic = transaction.atomic(using=db_name)
            atomic.__enter__()
            atomics[db_name] = atomic
        return atomics

    @classmethod
    def _rollback_atomics(cls, atomics):
        """
        Rollback atomic blocks opened by `_enter_atomics()`.

        `atomics` is the alias -> atomic block dict returned by that method.
        """
        # Unwind in the reverse of the order the blocks were entered.
        for db_name in reversed(cls._databases_names()):
            # Flag the block for rollback so that exiting it discards the data.
            transaction.set_rollback(True, using=db_name)
            atomics[db_name].__exit__(None, None, None)

    @classmethod
    def setUpClass(cls):
        """
        Open class-wide atomic blocks, then load fixtures and class test data.

        Does nothing beyond the parent's setup when the databases in use do
        not support transactions. If loading fixtures or `setUpTestData()`
        raises, the class-wide atomic blocks are rolled back and the exception
        is re-raised.
        """
        super().setUpClass()
        if not cls._databases_support_transactions():
            return
        cls.cls_atomics = cls._enter_atomics()

        if cls.fixtures:
            # Mirrors are skipped when loading fixtures.
            for db_name in cls._databases_names(include_mirrors=False):
                try:
                    call_command('loaddata', *cls.fixtures, verbosity=0, database=db_name)
                except Exception:
                    cls._rollback_atomics(cls.cls_atomics)
                    raise
        try:
            cls.setUpTestData()
        except Exception:
            cls._rollback_atomics(cls.cls_atomics)
            raise

    @classmethod
    def tearDownClass(cls):
        """Roll back the class-wide atomic blocks and close the connections."""
        if cls._databases_support_transactions():
            cls._rollback_atomics(cls.cls_atomics)
            # NOTE(review): nesting of this loop under the `if` follows
            # upstream Django's TestCase.tearDownClass -- confirm.
            for conn in connections.all():
                conn.close()
        super().tearDownClass()

    @classmethod
    def setUpTestData(cls):
        """Load initial data for the TestCase."""
        pass

    def _should_reload_connections(self):
        """Return False when transactions are supported; else defer to the parent."""
        if self._databases_support_transactions():
            return False
        return super()._should_reload_connections()

    def _fixture_setup(self):
        """
        Prepare the database state for a single test.

        Without transaction support, reload the class data and fall back to
        the parent's fixture setup. Otherwise open a per-test atomic block on
        each database and keep them in `self.atomics`.

        Raises AssertionError if `reset_sequences` is set while transactions
        are supported.
        """
        if not self._databases_support_transactions():
            # If the backend does not support transactions, we should reload
            # class data before each test
            self.setUpTestData()
            return super()._fixture_setup()

        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O.
        if self.reset_sequences:
            raise AssertionError('reset_sequences cannot be used on TestCase instances')
        self.atomics = self._enter_atomics()

    def _fixture_teardown(self):
        """
        Undo the per-test database changes.

        Without transaction support, defer to the parent's teardown.
        Otherwise run constraint checks on the connections that qualify (see
        `_should_check_constraints()`), then roll back the per-test atomic
        blocks -- the rollback happens even if a constraint check raises.
        """
        if not self._databases_support_transactions():
            return super()._fixture_teardown()
        try:
            for db_name in reversed(self._databases_names()):
                connection = connections[db_name]
                if self._should_check_constraints(connection):
                    connection.check_constraints()
        finally:
            self._rollback_atomics(self.atomics)

    def _should_check_constraints(self, connection):
        """
        Return whether constraints should be checked on `connection`.

        True only if the backend can defer constraint checks, the connection
        is not already marked as needing a rollback, and it is still usable.
        """
        return (
            connection.features.can_defer_constraint_checks and
            not connection.needs_rollback and connection.is_usable()
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment