Skip to content

Instantly share code, notes, and snippets.

@mfazekas
Last active June 10, 2016 10:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mfazekas/1710455 to your computer and use it in GitHub Desktop.
Save mfazekas/1710455 to your computer and use it in GitHub Desktop.
Parametric tests for nose/python unittest
#!/usr/bin/python
""" Implement parametric test methods and test cases for nose and unittest.

Usage for parametric test methods (only the decorated methods are expanded):

    @parametric
    class MyTest(unittest.TestCase):
        @parametric(foo=[1,2],bar=[3,4])
        def testWithParams(self,foo,bar):
            self.assertLess(foo,bar)
        def testNormal(self):
            self.assertEqual('foo','foo')

Usage for parametric test cases (every test method of the class is expanded):

    @parametric(foo=[1,2],bar=[3,4])
    class MyTest(unittest.TestCase):
        def testWithParams(self,foo,bar):
            self.assertLess(foo,bar)
        def testWithParams2(self,foo,bar):
            self.assertLess(-bar,-foo)

License: Copyright (c) 2008 Miklos Fazekas http://opensource.org/licenses/MIT
"""
import inspect
import itertools
import os
import re
import unittest
def _gen_params(basename,params):
''' returns all variations of parameters and basename '''
params = params.copy()
if len(params.keys()) > 0:
k = params.keys()[0]
vars = params.pop(k)
for vi in vars:
for (n,p) in _gen_params(basename,params):
p[k] = vi
yield (n+'_%s_%s'%(k,vi),p)
else:
yield (basename,{})
class GenParamsTestCase(unittest.TestCase):
    """Self-test for the ``_gen_params`` expansion helper."""

    def testAllVariations(self):
        # Both sides are sorted so the check does not depend on the order
        # in which combinations are produced.
        actual = sorted(_gen_params('foo', {'p1': [3, 4], 'p2': [1, 2]}))
        expected = sorted([
            ('foo_p1_3_p2_1', {'p1': 3, 'p2': 1}),
            ('foo_p1_3_p2_2', {'p1': 3, 'p2': 2}),
            ('foo_p1_4_p2_1', {'p1': 4, 'p2': 1}),
            ('foo_p1_4_p2_2', {'p1': 4, 'p2': 2}),
        ])
        self.assertEqual(expected, actual)
class ParametricMethod(object):
    """Internal placeholder for a parametric test method before expansion.

    Holds the original test function (``fun``) and its parameter grid
    (``params``: parameter name -> iterable of values).  Instances are not
    callable themselves; ``methods`` produces the concrete test functions
    that are installed on the test class.
    """

    def __init__(self, fun, params):
        super(ParametricMethod, self).__init__()
        self.fun = fun
        self.params = params

    def _gen_func(self, k, v):
        """Return a function named *k* that calls ``self.fun`` with kwargs *v*.

        Positional arguments (normally just ``self`` of the test case) are
        forwarded unchanged, and the wrapped function's result is returned.
        """
        def f(*args):
            return self.fun(*args, **v)

        f.__name__ = k
        # Keep the wrapped test's docstring so unittest's shortDescription()
        # still reports it for every generated variant.
        f.__doc__ = getattr(self.fun, '__doc__', None)
        return f

    def methods(self, basename):
        """Yield ``(name, function)`` for every combination in ``self.params``."""
        for (k, v) in _gen_params(basename, self.params):
            yield (k, self._gen_func(k, v))
def _set_all_methods_to_parametric(original_class, kwargs):
    """Make every test method of *original_class* parametric, then expand it.

    Each attribute defined directly on the class (``__dict__`` only) that
    is a callable routine whose name matches ``_testMatch`` is replaced by a
    ``ParametricMethod`` over the shared *kwargs* grid.  Afterwards, every
    ``ParametricMethod`` attribute, including ones created by a
    method-level ``@parametric(...)``, is expanded into concrete methods via
    ``setattr``.  The placeholders stay on the class under their original
    names.  The class is modified in place and returned.
    """
    placeholders = [
        (name, ParametricMethod(member, kwargs))
        for name, member in list(original_class.__dict__.items())
        if inspect.isroutine(member) and callable(member) and _testMatch.match(name)
    ]
    for name, placeholder in placeholders:
        setattr(original_class, name, placeholder)

    _expand_parametric_methods(
        original_class.__dict__,
        lambda name, value: setattr(original_class, name, value),
    )
    return original_class
def _expand_parametric_methods(class_dict,setattr):
""" This function expand all parametric methods into actual methods """
def is_param_method(name,value):
return isinstance(value,ParametricMethod)
tomod = []
for k in class_dict:
v = class_dict[k]
if is_param_method(k,v):
tomod.append((k,v))
for (k,v) in tomod:
for (m,f) in v.methods(k):
setattr(m,f)
def _parametric_func(func, kwargs):
    """Wrap a single test function in an unexpanded ``ParametricMethod``."""
    placeholder = ParametricMethod(func, kwargs)
    return placeholder
def _parametric_klass(klass, kwargs):
    """Make all test methods of *klass* parametric over *kwargs*; return *klass*."""
    expanded = _set_all_methods_to_parametric(klass, kwargs)
    return expanded
class ParametricTestCaseMetaClass(type):
    """Metaclass that expands ``ParametricMethod`` attributes of a class body.

    Expansion happens on the namespace dict before the class object is
    created, so the generated methods are ordinary class attributes.
    """

    def __new__(meta, classname, bases, classDict):
        # Generated methods are written straight into the class namespace.
        _expand_parametric_methods(classDict, classDict.__setitem__)
        return type.__new__(meta, classname, bases, classDict)
def parametric(*args, **kwargs):
    """Decorator for parametric test methods and test classes.

    Forms:

    * ``@parametric(name=values, ...)`` on a function returns an unexpanded
      ``ParametricMethod``; on a class, every test method is made parametric
      over the given grid and expanded in place.
    * bare ``@parametric`` on a class expands the methods of that class that
      were already decorated with ``@parametric(...)``.

    NOTE: keyword arguments passed together with a positional class are
    ignored.

    Raises:
        ValueError: more than one positional argument, or the single
            positional argument is not a class.
    """
    if not args:
        def parametric_func_or_klass(func_or_klass):
            if inspect.isclass(func_or_klass):
                return _parametric_klass(func_or_klass, kwargs)
            return _parametric_func(func_or_klass, kwargs)
        return parametric_func_or_klass

    if len(args) > 1:
        raise ValueError("No more than 1 argument is expected for @parametric!")
    orig_class = args[0]
    if not inspect.isclass(orig_class):
        raise ValueError("@parametric without args can be applied only to classes!")

    _expand_parametric_methods(
        orig_class.__dict__,
        lambda name, value: setattr(orig_class, name, value),
    )
    return orig_class
_testMatchPat = os.environ.get('NOSE_TESTMATCH',
r'(?:^|[\b_\.%s-])[Tt]est' % os.sep)
_testMatch = re.compile(_testMatchPat)
# Created by calling the metaclass directly instead of through a
# ``__metaclass__`` class attribute: Python 3 ignores that attribute, so
# subclasses (e.g. SubTest2) would silently keep their @parametric methods
# unexpanded.  A direct metaclass call behaves the same on Python 2 and 3.
ParametricTestCaseBase = ParametricTestCaseMetaClass(
    'ParametricTestCaseBase',
    (object,),
    {
        '__doc__': 'Mixin base: subclasses get their @parametric methods '
                   'expanded by ParametricTestCaseMetaClass at class creation.',
        '__module__': __name__,
        # Kept so code reading the Python 2 style attribute still finds it.
        '__metaclass__': ParametricTestCaseMetaClass,
    },
)
@parametric(p1=[4, 3])
class SubTest(unittest.TestCase):
    """Example: class-level @parametric makes every test method parametric."""

    def testWithParams(self, p1):
        self.assertEqual(p1, p1)
class SubTest2(unittest.TestCase, ParametricTestCaseBase):
    """Example: expansion triggered by inheriting ParametricTestCaseBase."""

    @parametric(p1=[4, 5])
    def testWithParams(self, p1):
        self.assertEqual(p1, p1)

    def testFoo(self):
        self.assertEqual(2, 2)
@parametric
class SubTest3(unittest.TestCase):
    """Example: bare class @parametric expands only the decorated methods."""

    @parametric(p1=[4, 5])
    def testWithParams(self, p1):
        self.assertEqual(p1 + 1, p1 + 1)

    def testFoo(self):
        self.assertEqual(2, 2)
if __name__ == "__main__":
    # Run this module's TestCases, including the generated parametric methods.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment