Skip to content

Instantly share code, notes, and snippets.

@rctay
Created June 11, 2010 16:44
Show Gist options
  • Save rctay/434735 to your computer and use it in GitHub Desktop.
Save rctay/434735 to your computer and use it in GitHub Desktop.
[django 1.2, 1.3] using nested transactions for testing

Description

Load fixtures once per TestCase class instead of once per test method, giving a nice performance boost.

This is done by grouping tests by TestCase class. Before any test of a TestCase runs, the TestCase's fixtures are loaded (as usual). A transaction savepoint is then created before each test method runs and rolled back after the method returns, so the fixture data is preserved between test method runs. In addition, TestCases whose fixture lists share a common leading sequence of fixtures are nested, so those shared fixtures are loaded only once.

Usage

  • Unpack the files into a package in your 'lib' directory; this example calls the package django_nested_txns.

  • In your settings.py:

    TEST_RUNNER = 'django_nested_txns.runner.TestSuiteRunner'
    
  • In your tests.py:

    from django_nested_txns.cases import TestCase
    ...
    
    class MyTestCase(TestCase):
      ...
    

Compatibility

Tested on Django 1.3; it should also work on Django 1.2.

from django.db import (
connections,
transaction,
DEFAULT_DB_ALIAS,
)
from django import test
from django.test.testcases import (
connections_support_transactions,
)
class TestCase(test.TestCase):
    """Django TestCase that loads its fixtures once per class.

    When run through this package's CasesSuite, every test of the class
    carries a shared ``_state`` object:

    * ``_state.is_first`` is true only for the first test of the class;
    * ``_state.savepoint_rollbacks`` collects ``(alias, savepoint)`` pairs.

    Only the first test performs Django's usual fixture setup. Every test
    (including the first) then runs inside its own savepoint, which is rolled
    back on teardown, so fixture data survives for the next test.

    Without ``_state`` (e.g. when run by Django's stock runner), or when the
    databases do not support transactions, this behaves exactly like
    ``django.test.TestCase``.
    """

    def _fixture_setup(self):
        if (not connections_support_transactions()
                or getattr(self, '_state', None) is None):
            return super(TestCase, self)._fixture_setup()
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]
        if self._state.is_first:
            # Django's regular setup (fixture loading) runs once per class;
            # later tests of the class reuse the data it created.
            super(TestCase, self)._fixture_setup()
        for db in databases:
            self._state.savepoint_rollbacks.append(
                (db, transaction.savepoint(using=db)))

    def _fixture_teardown(self):
        if (not connections_support_transactions()
                or getattr(self, '_state', None) is None):
            return super(TestCase, self)._fixture_teardown()
        # Undo only what this test did, leaving the fixture data in place.
        for db, savepoint in self._state.savepoint_rollbacks.pop_iter():
            transaction.savepoint_rollback(savepoint, using=db)
class NeedMore(Exception):
    """Raised by Trie lookups on a key that stores no value itself but is a
    prefix of longer keys that do."""
class Node(object):
"""Internal representation of Trie nodes."""
__slots__ = 'parent key nodes value'.split()
no_value = object()
def __init__(self, parent, key, nodes, value):
self.parent = parent
self.key = key
self.nodes = nodes
self.value = value
@property
def keypath(self):
n = self
keypath = [n.key for n in iter(lambda: n.parent, None) if n.key]
keypath.reverse()
keypath.append(self.key)
return keypath
def walk(self):
nodes = [self]
while nodes:
node = nodes.pop()
if node.value is not node.no_value:
yield node
nodes.extend(node.nodes[key] for key in sorted(node.nodes, reverse=True))
class Trie(object):
    """A simple prefix tree (trie) implementation.

    Keys are sequences (e.g. strings or lists); each element of a key selects
    one level of the tree.

    If attempting to access a node without a value, but with descendents,
    NeedMore will be raised. If there are no descendents, KeyError will be
    raised.

    Usage:
    >>> from pprint import pprint
    >>> t = Trie()
    >>> t['foobaz'] = 'Here is a foobaz.'
    >>> t['foobar'] = 'This is a foobar.'
    >>> t['fooqat'] = "What's a fooqat?"
    >>> pprint(list(t))
    [['f', 'o', 'o', 'b', 'a', 'r'],
     ['f', 'o', 'o', 'b', 'a', 'z'],
     ['f', 'o', 'o', 'q', 'a', 't']]
    >>> pprint(list(t.iteritems()))
    [(['f', 'o', 'o', 'b', 'a', 'r'], 'This is a foobar.'),
     (['f', 'o', 'o', 'b', 'a', 'z'], 'Here is a foobaz.'),
     (['f', 'o', 'o', 'q', 'a', 't'], "What's a fooqat?")]
    >>> t['foo']
    Traceback (most recent call last):
        ...
    NeedMore
    >>> t['fooqux']
    Traceback (most recent call last):
        ...
    KeyError: 'fooqux'
    >>> sorted(t.children('fooba').items())
    [('r', 'This is a foobar.'), ('z', 'Here is a foobaz.')]
    >>> del t['foobaz']
    >>> pprint(list(t.iteritems()))
    [(['f', 'o', 'o', 'b', 'a', 'r'], 'This is a foobar.'),
     (['f', 'o', 'o', 'q', 'a', 't'], "What's a fooqat?")]

    Deleting a key never discards a value stored under one of its prefixes:
    >>> t['foo'] = 'Just foo.'
    >>> del t['foobar']
    >>> del t['fooqat']
    >>> t['foo']
    'Just foo.'
    """

    def __init__(self, root_data=Node.no_value, mapping=()):
        """Initialize a Trie instance.

        Args (both optional):
          root_data: value of the root node (ie. Trie('hello')[()] == 'hello').
          mapping: a sequence of (key, value) pairs to initialize with.
        """
        self.root = Node(None, None, {}, root_data)
        self.extend(mapping)

    def extend(self, mapping):
        """Update the Trie with a sequence of (key, value) pairs."""
        for k, v in mapping:
            self[k] = v

    def __setitem__(self, k, v):
        """Store *v* under key *k*, creating intermediate nodes as needed."""
        n = self.root
        for c in k:
            child = n.nodes.get(c)
            if child is None:
                # Only allocate a node when the path does not exist yet.
                child = n.nodes[c] = Node(n, c, {}, Node.no_value)
            n = child
        n.value = v

    def _getnode(self, k):
        """Return the node at the end of path *k*; KeyError(k) if absent."""
        n = self.root
        for c in k:
            try:
                n = n.nodes[c]
            except KeyError:
                raise KeyError(k)
        return n

    def __getitem__(self, k):
        """Return the value for *k*.

        Raises NeedMore if *k* has no value but longer keys start with it,
        and KeyError if *k* is neither stored nor a prefix of a stored key.
        """
        n = self._getnode(k)
        if n.value is Node.no_value:
            if n.nodes:
                raise NeedMore()
            else:
                raise KeyError(k)
        return n.value

    def __delitem__(self, k):
        """Remove the value stored under *k*; KeyError if there is none.

        Nodes left with neither children nor a value of their own are pruned
        upwards. Pruning stops at the first ancestor that still holds a value
        or has other children, and the root is never removed.
        """
        n = self._getnode(k)
        if n.value is Node.no_value:
            raise KeyError(k)
        n.value = Node.no_value
        while (n.parent is not None and not n.nodes
               and n.value is Node.no_value):
            del n.parent.nodes[n.key]
            n = n.parent

    def children(self, k):
        """Return a dict of the immediate children of the given key.

        Only children that hold a value are included; they are keyed by
        their single key element.

        Example:
        >>> t = Trie()
        >>> t['foobaz'] = 'Here is a foobaz.'
        >>> t['foobar'] = 'This is a foobar.'
        >>> sorted(t.children('fooba').items())
        [('r', 'This is a foobar.'), ('z', 'Here is a foobaz.')]
        """
        n = self._getnode(k)
        return dict((c, child.value)
                    for c, child in n.nodes.items()
                    if child.value is not Node.no_value)

    def __iter__(self):
        """Yield the keys (as lists of key elements) in order."""
        for node in self.root.walk():
            yield node.keypath

    def iteritems(self):
        """Yield (key, value) pairs in order."""
        for node in self.root.walk():
            yield node.keypath, node.value

    def itervalues(self):
        """Yield values in order."""
        for node in self.root.walk():
            yield node.value
if __name__ == '__main__':
    # Executing this module directly runs the examples in its docstrings.
    from doctest import testmod
    testmod()
import collections
import itertools
import unittest
from django.db import (
DEFAULT_DB_ALIAS,
connections,
transaction,
)
from django.test import testcases
from django.test.simple import DjangoTestSuiteRunner
from django.test.testcases import (
connections_support_transactions,
nop,
)
import dhain_trie
# Aliases that RootSuite puts under transaction management and that
# PrefixSuite wraps in savepoints; only the default database is covered.
# NOTE(review): TestCase with multi_db=True creates savepoints on every
# connection, not only on these aliases -- confirm that is intended.
DATABASES = [
    DEFAULT_DB_ALIAS,
]
def iter_isfirst(iterable):
    """
    Yield ``(item, is_first)`` for each item of *iterable*.

    ``is_first`` is True for the first item only. An empty iterable yields
    nothing.

    Based on [1] to indicate first-run too.
    [1] http://code.activestate.com/recipes/392015-finding-the-last-item-in-a-loop/
    """
    # A plain flag avoids the Python-2-only ``it.next()`` and the
    # StopIteration-inside-a-generator pitfall (PEP 479) on empty input.
    first = True
    for item in iterable:
        yield item, first
        first = False
def stripPrefix(seq, prefix):
    """Remove from the front of *seq* the elements it shares with *prefix*.

    *seq* must be a mutable sequence and is modified in place. Returns the
    list of removed elements (possibly empty). Returns None, leaving *seq*
    untouched, when *seq* or *prefix* is empty or None.
    """
    if not seq or not prefix:
        return None
    trimmed = []
    # zip stops at the shorter of the two, so a prefix longer than seq (or
    # vice versa) needs no special handling.
    for item, expected in zip(seq, prefix):
        if item != expected:
            break
        trimmed.append(item)
    del seq[:len(trimmed)]
    return trimmed
def bump(obj, attr, newval):
    """Set ``obj.<attr>`` to *newval* and return the value it replaced.

    Returns None when *obj* had no such attribute. Callers restore the old
    value with ``bump(obj, attr, previous)``.
    """
    try:
        previous = getattr(obj, attr)
    except AttributeError:
        previous = None
    setattr(obj, attr, newval)
    return previous
class Bunch(object):
    """Attribute bag: keyword arguments become instance attributes.

    Based on the Python Cookbook 4.18
    """

    def __init__(self, **kwargs):
        self.update(**kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) one instance attribute per keyword argument."""
        vars(self).update(kwargs)
class Queue(collections.deque):
    """deque with a draining iterator.

    Despite the name, ``pop_iter`` takes items from the right end, i.e. last
    in, first out.
    """

    def pop_iter(self):
        """Pop and yield items from the right end until the deque is empty.

        Stops cleanly on an empty deque. It never raises StopIteration
        explicitly, which PEP 479 would turn into RuntimeError inside a
        generator.
        """
        while self:
            yield self.pop()
class SuiteLikeTest(unittest.TestSuite):
    """TestSuite that exposes the ``fixtures`` of its first member."""

    @property
    def firstTestClass(self):
        """The first test in this suite, or None when the suite is empty.

        Despite the name, this is the test object (instance), not its class.
        """
        for first in self:
            return first
        return None

    @property
    def fixtures(self):
        """The first test's ``fixtures`` attribute.

        None when the suite is empty or the first test defines no fixtures.
        """
        first = self.firstTestClass
        if not first:
            return None
        return getattr(first, 'fixtures', None)
class PrefixSuite(SuiteLikeTest):
    """Suite that runs inside a savepoint and shares fixture prefixes.

    ``_prefixes`` are the fixture names this suite adds on top of its
    enclosing suites. ``_parentPrefixes`` are the names belonging to those
    enclosing suites.

    While the suite runs, the leading ``_parentPrefixes`` are stripped in
    place from the first test's ``fixtures`` list, so the tests' own fixture
    setup does not load them again. The list is restored afterwards.
    """

    def __init__(self, *args, **kwargs):
        self._prefixes = []
        self._parentPrefixes = []
        super(PrefixSuite, self).__init__(*args, **kwargs)

    def addSuite(self, suite, prefixes):
        """Nest *suite* under this one, with *prefixes* as its own fixtures.

        Returns ``(suite, [])`` so callers can keep accumulating prefixes for
        suites nested below *suite*.
        """
        suite._parentPrefixes = self._parentPrefixes + self._prefixes
        suite._prefixes = prefixes
        super(PrefixSuite, self).addTest(suite)
        return suite, []

    def __call__(self, *args, **kwargs):
        fixtures = self.fixtures
        # Drop (in place) fixture names already handled by enclosing suites.
        trimmed = stripPrefix(fixtures, self._parentPrefixes)
        savepoints = [(db, transaction.savepoint(using=db)) for db in DATABASES]
        try:
            ret = super(PrefixSuite, self).__call__(*args, **kwargs)
            # Discard everything written since the savepoint (including any
            # fixtures loaded by this suite's tests) before siblings run.
            for db, savepoint in savepoints:
                transaction.savepoint_rollback(savepoint, using=db)
        finally:
            # Put the stripped names back even when every fixture was
            # stripped (``fixtures`` is then an empty, falsy list) or the run
            # raised.
            if trimmed:
                fixtures[0:0] = trimmed
        return ret
class RootSuite(PrefixSuite):
    """Outermost suite wrapping the whole reorganised test run.

    For the duration of the run it:

    * keeps every alias in DATABASES under managed transactions;
    * calls Django's ``disable_transaction_methods()`` once, and replaces the
      module-level ``disable_/restore_transaction_methods`` helpers in
      ``django.test.testcases`` with ``nop``, so per-test calls do nothing;
    * makes every connection's ``close()`` a no-op.

    Afterwards it always restores all of the above, even if the run raises,
    and rolls the databases back.
    """

    def __call__(self, *args, **kwargs):
        for db in DATABASES:
            transaction.enter_transaction_management(using=db)
            transaction.managed(True, using=db)
        restore_transaction_methods = bump(testcases, 'restore_transaction_methods', nop)
        disable_transaction_methods = bump(testcases, 'disable_transaction_methods', nop)
        real_closes = [(c, bump(c, 'close', nop)) for c in connections.all()]
        try:
            disable_transaction_methods()
            # this is *not* a typo - skip PrefixSuite's overridden __call__:
            # the root needs no savepoint or fixture stripping, the whole run
            # is rolled back below.
            return super(PrefixSuite, self).__call__(*args, **kwargs)
        finally:
            bump(testcases, 'restore_transaction_methods', restore_transaction_methods)
            bump(testcases, 'disable_transaction_methods', disable_transaction_methods)
            restore_transaction_methods()
            for db in DATABASES:
                transaction.rollback(using=db)
                transaction.leave_transaction_management(using=db)
            for conn, real_close in real_closes:
                bump(conn, 'close', real_close)
                conn.close()
class CasesSuite(PrefixSuite):
    """Suite holding the tests of a single TestCase class.

    While ``run`` is active, iterating the suite attaches one shared
    ``_state`` Bunch to every test: ``is_first`` marks the first test, and
    ``savepoint_rollbacks`` is a Queue of pending savepoints.
    """

    def __dec_iter(self):
        # One fresh state per iteration, i.e. per run of this suite.
        state = Bunch(savepoint_rollbacks=Queue())
        for test, first in iter_isfirst(super(CasesSuite, self).__iter__()):
            # The state object is shared and mutated; this relies on the suite
            # running each test before requesting the next one.
            state.update(is_first=first)
            setattr(test, '_state', state)
            yield test

    def __iter__(self):
        if not getattr(self, '_needs_state', False):
            return super(CasesSuite, self).__iter__()
        return self.__dec_iter()

    def run(self, *args, **kwargs):
        self._needs_state = True
        try:
            return super(CasesSuite, self).run(*args, **kwargs)
        finally:
            # Reset even if the run raised, so later plain iteration (e.g. via
            # the ``fixtures`` property) does not hand out fresh state.
            self._needs_state = False
class Tree(dhain_trie.Trie):
    """Trie of CasesSuites keyed by their fixture lists.

    A suite whose fixture list extends another suite's list sits below it in
    the trie. asSuite then nests it under that suite.
    """

    def addSuite(self, suite):
        """Store *suite* under its fixture list.

        Returns False, storing nothing, when the suite's ``fixtures`` is
        None; otherwise returns True. A suite whose fixture list equals one
        already stored is added as a test of that existing suite.
        """
        fixtures = suite.fixtures
        if fixtures is None:
            return False
        try:
            existing = self[fixtures]
        except (KeyError, dhain_trie.NeedMore):
            self[fixtures] = suite
            return True
        # we could have just stripped all the fixtures; by setting
        # _parentPrefixes, we get the modify-restore functionality
        # of PrefixSuite.
        suite._parentPrefixes = existing.fixtures[:]
        existing.addTest(suite)
        return True

    def asSuite(self):
        """Turn the trie into a RootSuite of nested PrefixSuites.

        Each stored suite is attached, via ``addSuite``, to the suite of its
        nearest ancestor node holding one. It is passed the fixture names
        found on the path between them.
        """
        root_suite = RootSuite()
        pending = [(root_suite, [], self.root)]
        while pending:
            parent_suite, path, trie_node = pending.pop()
            if trie_node.key:
                path.append(trie_node.key)
            if trie_node.value is not trie_node.no_value:
                parent_suite, path = parent_suite.addSuite(trie_node.value, path)
            for child in trie_node.nodes.values():
                # Each child gets its own copy of the accumulated path.
                pending.append((parent_suite, path[:], child))
        return root_suite
class TestSuiteRunner(DjangoTestSuiteRunner):
    """Django test runner that regroups tests to share fixture loading.

    When every connection supports transactions, the tests are grouped into
    one CasesSuite per test class. Those suites are arranged in a fixture
    Tree, and suites without fixtures are appended at the end. Otherwise the
    suite is handed to Django's runner unchanged.
    """

    def run_suite(self, suite, **kwargs):
        parent = super(TestSuiteRunner, self)
        if not connections_support_transactions():
            return parent.run_suite(suite, **kwargs)
        # collate by test's class
        by_class = {}
        for case in suite:
            kls = case.__class__
            if kls not in by_class:
                by_class[kls] = CasesSuite()
            by_class[kls].addTest(case)
        tree = Tree()
        leftovers = []
        for cases in by_class.values():
            if not tree.addSuite(cases):
                leftovers.append(cases)
        nested = tree.asSuite()
        nested.addTests(leftovers)
        return parent.run_suite(nested, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment