Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Created July 24, 2013 23:13
Show Gist options
  • Save eltjpm/6075487 to your computer and use it in GitHub Desktop.
Save eltjpm/6075487 to your computer and use it in GitHub Desktop.
Index: athena/src/numba/numba/testing/test_support.py
===================================================================
--- athena/src/numba/numba/testing/test_support.py (revision 83368)
+++ athena/src/numba/numba/testing/test_support.py (working copy)
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
+import itertools
import os
import sys
import types
import unittest
-import functools
try:
from nose.tools import nottest
except ImportError:
@@ -159,34 +159,48 @@
# Test Parametrization
#------------------------------------------------------------------------
-def parametrize(*parameters):
+def parametrize(*parameters, **named_parameters):
"""
@parametrize('foo', 'bar')
def test_func(foo_or_bar):
print foo_or_bar # prints 'foo' or 'bar'
+ or
+
+ @parametrize(x=['foo', 'bar'], y=['baz', 'quux'])
+ def test_func(x, y):
+ print x, y # prints all combinations
+
Generates a unittest TestCase in the function's global scope named
'test_func_testcase' with parametrized test methods.
':return: The original function
"""
+ if parameters and named_parameters:
+ raise TypeError('Cannot specify both parameters and named_parameters')
+
def decorator(func):
class TestCase(unittest.TestCase):
pass
TestCase.__name__ = func.__name__
+ names = named_parameters.keys()
+ values = parameters or itertools.product(*named_parameters.values())
- for i, parameter in enumerate(parameters):
+ for i, parameter in enumerate(values):
name = 'test_%s_%d' % (func.__name__, i)
- def testfunc(self, parameter=parameter):
- return func(parameter)
+ if names:
+ def testfunc(self, parameter=parameter):
+ return func(**dict(zip(names, parameter)))
+ else:
+ def testfunc(self, parameter=parameter):
+ return func(parameter)
testfunc.__name__ = name
if func.__doc__:
testfunc.__doc__ = func.__doc__.replace(func.__name__, name)
- # func.func_globals[name] = unittest.FunctionTestCase(testfunc)
setattr(TestCase, name, testfunc)
@@ -194,3 +208,4 @@
return func
return decorator
+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.