Created
January 27, 2019 10:56
-
-
Save elfgzp/a18826f4da6a616be2286dfb8d4edead to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
__author__ = 'gzp' | |
from django.core.management import call_command | |
from django.db import DEFAULT_DB_ALIAS, connections, transaction | |
from django.test.testcases import TransactionTestCase | |
class TestCase(TransactionTestCase):
    """
    Transaction-wrapped test case that only inspects the databases it acts on.

    Translated from the original author's note: this class fixes a problem in
    Django's ``TestCase``. When a project configures multiple databases but
    ``multi_db`` is not set, unit tests still run against a single database
    only. Django's ``connections_support_transactions``, however, takes every
    configured database into account. As a result, when MongoDB and MySQL are
    used together, the MySQL database cannot be rolled back, all initial data
    is flushed, and unit tests cannot use the initial data.

    The transaction-support check here is therefore restricted to the aliases
    returned by ``_databases_names()``.
    """

    @classmethod
    def _databases_support_transactions(cls):
        """
        Return True if every connection this test case acts on supports
        transactions.

        Only connections whose alias is returned by ``_databases_names()``
        are checked; every other configured connection is ignored.
        """
        # Resolve the alias list once rather than once per connection.
        aliases = frozenset(cls._databases_names())
        return all(
            conn.features.supports_transactions
            for conn in connections.all()
            if conn.alias in aliases
        )

    @classmethod
    def _databases_names(cls, include_mirrors=True):
        """
        Return the list of database aliases this test case acts on.

        With ``multi_db`` truthy, every alias in ``connections`` is returned;
        aliases whose ``TEST['MIRROR']`` setting is truthy are dropped when
        ``include_mirrors`` is False. Otherwise only ``DEFAULT_DB_ALIAS`` is
        returned.
        """
        # If the test case has a multi_db=True flag, act on all databases,
        # including mirrors or not. Otherwise, just on the default DB.
        if cls.multi_db:
            return [
                alias for alias in connections
                if include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
            ]
        return [DEFAULT_DB_ALIAS]

    @classmethod
    def _enter_atomics(cls):
        """
        Open an atomic block on each database returned by
        ``_databases_names()``.

        Returns a dict mapping each alias to its entered ``atomic`` object.
        """
        atomics = {}
        for db_name in cls._databases_names():
            atomics[db_name] = transaction.atomic(using=db_name)
            atomics[db_name].__enter__()
        return atomics

    @classmethod
    def _rollback_atomics(cls, atomics):
        """
        Roll back the atomic blocks opened by ``_enter_atomics()``.

        ``atomics`` is the alias-to-atomic dict returned by that method. The
        databases are processed in reverse order of ``_databases_names()``.
        """
        for db_name in reversed(cls._databases_names()):
            # Mark the transaction for rollback so exiting discards changes.
            transaction.set_rollback(True, using=db_name)
            atomics[db_name].__exit__(None, None, None)

    @classmethod
    def setUpClass(cls):
        """
        Open class-level atomic blocks, then load fixtures and class data.

        Does nothing beyond the parent's ``setUpClass()`` when the databases
        in use do not all support transactions. If loading fixtures or
        ``setUpTestData()`` raises, the class-level atomic blocks are rolled
        back and the exception is re-raised.
        """
        super().setUpClass()
        if not cls._databases_support_transactions():
            return
        cls.cls_atomics = cls._enter_atomics()

        if cls.fixtures:
            for db_name in cls._databases_names(include_mirrors=False):
                try:
                    call_command(
                        'loaddata', *cls.fixtures,
                        verbosity=0, database=db_name,
                    )
                except Exception:
                    cls._rollback_atomics(cls.cls_atomics)
                    raise
        try:
            cls.setUpTestData()
        except Exception:
            cls._rollback_atomics(cls.cls_atomics)
            raise

    @classmethod
    def tearDownClass(cls):
        """
        Roll back the class-level atomic blocks and close all connections.

        The rollback and connection closing only happen when the databases
        in use all support transactions; the parent's ``tearDownClass()`` is
        always called.
        """
        if cls._databases_support_transactions():
            cls._rollback_atomics(cls.cls_atomics)
            for conn in connections.all():
                conn.close()
        super().tearDownClass()

    @classmethod
    def setUpTestData(cls):
        """Load initial data for the TestCase. No-op hook for subclasses."""

    def _should_reload_connections(self):
        """
        Return False when the databases in use all support transactions;
        otherwise defer to the parent implementation.
        """
        if self._databases_support_transactions():
            return False
        return super()._should_reload_connections()

    def _fixture_setup(self):
        """
        Prepare the database state for a single test.

        Without transaction support, reload the class data and fall back to
        the parent's fixture setup. With transaction support, open per-test
        atomic blocks and store them on ``self.atomics``.

        Raises AssertionError if ``reset_sequences`` is set while transactions
        are supported.
        """
        if not self._databases_support_transactions():
            # If the backend does not support transactions, we should reload
            # class data before each test
            self.setUpTestData()
            return super()._fixture_setup()

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if self.reset_sequences:
            raise AssertionError('reset_sequences cannot be used on TestCase instances')
        self.atomics = self._enter_atomics()

    def _fixture_teardown(self):
        """
        Undo the database changes made by a single test.

        Without transaction support, defer to the parent's teardown. With
        transaction support, run constraint checks on each database where
        ``_should_check_constraints()`` allows it, then roll back the
        per-test atomic blocks even if a check raises.
        """
        if not self._databases_support_transactions():
            return super()._fixture_teardown()
        try:
            for db_name in reversed(self._databases_names()):
                if self._should_check_constraints(connections[db_name]):
                    connections[db_name].check_constraints()
        finally:
            self._rollback_atomics(self.atomics)

    def _should_check_constraints(self, connection):
        """
        Return whether constraints should be checked on ``connection``.

        True only when the backend can defer constraint checks, the
        connection is not flagged as needing rollback, and it is usable.
        """
        return (
            connection.features.can_defer_constraint_checks and
            not connection.needs_rollback and connection.is_usable()
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment