Skip to content

Instantly share code, notes, and snippets.

@textbook
Last active Dec 27, 2015
Embed
What would you like to do?
Collection of useful bits and pieces of Python code, plus a few simple test routines - now pylint compliant
"""Simple unit testing functionality."""
from __future__ import print_function
def _run_function_test(function, expected, args=None, kwargs=None):
    """Check whether function returns/raises expected with supplied args.

    Arguments:
        function: the callable under test.
        expected: either an exception class that the call must raise, or
            the value it must return (float results compared against a
            numeric expected value use _compare_floats).
        args (tuple, optional): positional arguments for the call.
        kwargs (dict, optional): keyword arguments for the call.

    Returns True if the call behaved as expected, otherwise False.
    Exceptions not derived from Exception (e.g. KeyboardInterrupt) are not
    swallowed unless they are the expected outcome.
    """
    if args is None:
        args = tuple()
    if kwargs is None:
        kwargs = dict()
    # Inspect the class instead of instantiating it, so that callables
    # passed as expected values are never invoked, and exception classes
    # whose constructors require arguments are still recognised.
    if isinstance(expected, type) and issubclass(expected, BaseException):
        try:
            function(*args, **kwargs)  # pylint: disable=star-args
        except expected:
            return True
        except Exception:  # pylint: disable=broad-except
            return False
        return False
    try:
        outcome = function(*args, **kwargs)  # pylint: disable=star-args
    except Exception:  # pylint: disable=broad-except
        return False
    # Tolerant comparison only makes sense against a number; anything else
    # (None, str, ...) falls through to plain equality instead of crashing.
    if isinstance(outcome, float) and isinstance(expected, (int, float)):
        return _compare_floats(outcome, expected)
    return outcome == expected
def _compare_floats(float1, float2, tolerance=0.0000001):
    """Compare float numbers to a specified absolute tolerance.

    Returns True if the values are exactly equal, or differ by strictly
    less than tolerance. The equality short-circuit makes matching
    infinities (whose difference is NaN) compare equal, and makes
    identical values equal even when tolerance is 0.
    """
    return float1 == float2 or abs(float1 - float2) < tolerance
def test_function(tests, function, verbose=True):
    """Run a dictionary of tests on the supplied function.

    Arguments:
        tests (dict): maps each test name to the arguments for
            _run_function_test after the function itself, i.e. a tuple
            (expected[, args[, kwargs]]).
        function: the callable under test.
        verbose (bool, optional): print a header and one passed/failed
            line per test, in sorted name order (defaults to True).

    Returns True only if every test passed.
    """
    outcomes = {name: _run_function_test(function, *spec)  # pylint: disable=star-args
                for name, spec in tests.items()}
    if verbose:
        # Callables such as functools.partial have no __name__.
        label = getattr(function, "__name__", repr(function))
        print("Testing {}()".format(label))
        for name in sorted(outcomes):
            # {0} rather than {0:d}: test names need not be integers.
            print("Test {0} {1}".format(name,
                                        "passed" if outcomes[name] else "failed"))
    return all(outcomes.values())
# Note: no automated test coverage for print_outcome()
def print_outcome(success):
    """Print a one-line summary of a whole test run.

    success (bool): the overall result, e.g. the value returned by
    test_function; any truthy value counts as a pass.
    """
    if success:
        verdict = "passed"
    else:
        verdict = "failed"
    print("Test routine {0:s}".format(verdict))
if __name__ == "__main__":
    # Self-test of the helpers above; every assert must hold.
    # error handling
    TEST_RAISE_NAMEERROR = lambda: x  # pylint: disable=undefined-variable
    assert _run_function_test(TEST_RAISE_NAMEERROR, NameError)
    assert not _run_function_test(TEST_RAISE_NAMEERROR, ValueError)
    # float handling
    assert not 1 + 0.1 == 1.2 - 0.1
    assert _compare_floats(1 + 0.1, 1.2 - 0.1, 0.01)
    assert not _compare_floats(1 + 0.1, 2, 0.1)
    TEST_RETURN_FLOAT = lambda: 1 + 0.1
    assert _run_function_test(TEST_RETURN_FLOAT, 1.2 - 0.1)
    assert not _run_function_test(TEST_RETURN_FLOAT, 3.0)
    # non-float handling
    TEST_RETURN_INT = lambda: 10
    assert _run_function_test(TEST_RETURN_INT, 10)
    assert not _run_function_test(TEST_RETURN_INT, 11)
    # arguments
    TEST_HANDLE_ARGS = lambda *args: sum(args)
    assert _run_function_test(TEST_HANDLE_ARGS, 6, (1, 2, 3))
    TEST_HANDLE_KWARGS = lambda **kwargs: sum(kwargs.values())
    assert _run_function_test(TEST_HANDLE_KWARGS, 3, kwargs={'a': 1, 'b': 2})
    # test_function
    TEST_FUNCTION = lambda *args, **kwargs: sum(kwargs.values()) + sum(args)
    TESTS = {1: (10, (1, 2), {'a': 3, 'b': 4}),
             2: (3, (1, 2)),
             3: (7, None, {'a': 3, 'b': 4}),
             # args is None, so the dict is passed as keyword arguments and
             # the TypeError comes from adding '4' to 3. In the args slot it
             # would be unpacked into its keys ('a', 'b') instead.
             4: (TypeError, None, {'a': 3, 'b': '4'})}
    assert test_function(TESTS, TEST_FUNCTION, False)
    # success
    print("Testing works.")
"""A collection of useful, reusable functions."""
from __future__ import print_function
from itertools import combinations
from math import ceil, sqrt
from numbers import Real
from operator import mul
from os import walk, sep
from string import punctuation
from sys import version_info
# Paper over the Python 2.x / 3.x differences the helpers below rely on:
# Python 3 moved reduce into functools and dropped basestring, while on
# Python 2 input() is aliased to raw_input() so callers get the raw string.
# pylint: disable=redefined-builtin
if version_info.major >= 3:
    from functools import reduce
    basestring = str  # pylint: disable=invalid-name
else:
    input = raw_input  # pylint: disable=invalid-name,undefined-variable
def mean(seq, method="arithmetic"):
    """Returns the mean of values in seq, using the specified method.

    Arguments:
        seq (list, tuple, gen): the values; every item must be a real
            number (any numbers.Real, e.g. int, float or
            fractions.Fraction).
        method (str, optional): either "arithmetic", "harmonic",
            "geometric" or "quadratic" (defaults to "arithmetic").

    Raises:
        TypeError: if any item of seq is not a real number (checked before
            method is validated).
        ValueError: if an invalid method is supplied.
        ZeroDivisionError: if seq is empty, or, for "harmonic", if it
            contains a zero.
    """
    seq = list(seq)  # materialise: the values are traversed more than once
    # numbers.Real also admits Fraction, Python 2 long and NumPy scalar
    # types, which a plain (float, int) check would reject.
    if not all(isinstance(x, Real) for x in seq):
        raise TypeError("All items in seq must be numerical")
    count = float(len(seq))
    if method == "arithmetic":
        return sum(seq) / count
    elif method == "harmonic":
        return count / sum((1.0 / x) for x in seq)
    elif method == "geometric":
        return product(seq) ** (1.0 / count)
    elif method == "quadratic":
        return (sum(x ** 2 for x in seq) / count) ** 0.5
    raise ValueError("'{0!s}' is not a valid method".format(method))
def std_dev(seq):
    """Return the (population) standard deviation of the values in seq.

    seq may be any iterable, including a generator: it is materialised once
    so that it can be traversed twice. The two-pass form
    sqrt(mean((x - m) ** 2)) is used because the algebraically equal
    sqrt(mean(x ** 2) - m ** 2) can go slightly negative through float
    rounding (e.g. for constant data), which makes sqrt raise ValueError.

    Raises whatever mean() raises: TypeError for non-numeric items and
    ZeroDivisionError for an empty seq.
    """
    values = list(seq)
    centre = mean(values)
    return sqrt(mean((x - centre) ** 2 for x in values))
def product(seq):
    """Returns the cumulative product of all terms in a sequence (1 if empty).

    seq may be any iterable. Emptiness is detected by trying to take the
    first term rather than by bool(seq), because generators are always
    truthy and some array-likes have no single truth value. Terms are
    combined left to right with operator.mul, starting from the first term,
    so non-numeric terms behave as with `*` (e.g. [1, 2, "a"] gives "aa").
    """
    terms = iter(seq)
    try:
        first = next(terms)
    except StopIteration:
        return 1
    return reduce(mul, terms, first)
def factorial(num):
    """Calculate num!, the factorial of num.

    Delegates to math.factorial, which is implemented in C.

    Raises:
        TypeError: if num is not an integer (e.g. a str or float).
        ValueError: if num is negative (0! is 1, but n! is undefined for
            n < 0).
    """
    import math
    from operator import index
    # operator.index rejects floats and strings with TypeError on every
    # Python version; older math.factorial accepted integral floats.
    return math.factorial(index(num))
### not needed - use itertools.combinations(seq, 2)
##def pairs(seq):
## """Recursively return a list of tuples of pairs of items in a sequence"""
## if len(seq) == 2:
## return [(seq[0], seq[1])]
## return [(seq[0], seq[i])
## for i in range(1, len(seq))] + pairs(seq[1:])
### not needed - use "{:,}".format(num)
##def format_number(num, sep=",", block=3):
## if block <= 0:
## raise ValueError("block must be greater than zero")
## if num < (10 ** block):
## return str(num)
## return sep.join([format_number((num / (10 ** block)), sep, block),
## str(num)[-block:]])
def evaluate_polynomial(poly, num):
    """Return the polynomial poly evaluated at num, as a float.

    poly lists the coefficients in ascending order of power, so
    [c0, c1, ..., cn] stands for c0 + c1*num + ... + cn*num**n. An empty
    list is the zero polynomial.
    """
    terms = [coefficient * num ** power
             for power, coefficient in enumerate(poly)]
    return float(sum(terms))
def percentile(seq, perc, sort=sorted):
    """Return the perc'th percentile of a list of items seq (nearest rank).

    Arguments:
        seq (iterable): the items; must not be empty.
        perc (float): the percentile as a fraction from 0 to 1 inclusive
            (e.g. 0.3 for the 30th percentile).
        sort (callable, optional): returns the items of seq as a ranked
            sequence (defaults to sorted, i.e. ascending).

    The result is the item of rank ceil(perc * n), counting from 1 and
    clamped to at least 1. perc=0 therefore gives the first ranked item and
    perc=1 the last.

    Raises:
        ValueError: if perc is outside [0, 1].
        IndexError: if seq is empty.
    """
    if not 0 <= perc <= 1:
        raise ValueError("perc must be between 0 and 1 inclusive.")
    ranked = sort(seq)
    # Round away float noise first (0.7 * 10 == 7.000000000000001) so that
    # an exact rank is not bumped up by ceil.
    rank = int(ceil(round(perc * len(ranked), 9)))
    return ranked[max(rank, 1) - 1]
def is_palindrome(text, ignore=" "+punctuation):
    """Return whether text reads the same forwards and backwards.

    Each character is lower-cased individually before comparison, and any
    character found in ignore is dropped first. ignore may be any iterable
    of characters; by default it is a space plus string.punctuation.
    """
    excluded = set(ignore)
    kept = []
    for character in text:
        if character not in excluded:
            kept.append(character.lower())
    return kept == list(reversed(kept))
# Note: no automated test coverage for get_filelist()
def get_filelist(root, extensions=None):
    """Return a list of files (path and name) within a supplied root directory.

    The tree under root is walked recursively with os.walk. To filter by
    extension(s), provide an iterable of strings, with or without the
    leading dot, e.g.
        get_filelist(root, ["zip", ".csv"])
    Matching is case-sensitive and uses os.path.splitext, so a file named
    just "zip", or a dotfile such as ".bashrc", has no extension.

    Returns an empty list if root does not exist or nothing matches.
    """
    from os.path import join, splitext
    wanted = set(ext.lstrip(".") for ext in (extensions or ()))
    found = []
    for dirpath, _, filenames in walk(root):
        for name in filenames:
            if not wanted or splitext(name)[1][1:] in wanted:
                found.append(join(dirpath, name))
    return found
def flatten(input_, output=None):
    """Flatten the input to a single list of non-iterables and strings.

    Arguments:
        input_: a string (appended whole) or any iterable. Nested iterables
            are expanded recursively, strings are never split into
            characters, and dicts contribute their keys.
        output (list, optional): list to append to (a new list is created
            by default).

    Returns output.

    Raises:
        TypeError: if input_ itself is neither a string nor iterable.
    """
    if output is None:
        output = []
    if isinstance(input_, basestring):
        output.append(input_)
        return output
    for item in input_:
        try:
            iter(item)
        except TypeError:
            # Not iterable: a leaf value.
            output.append(item)
        else:
            # Recurse outside the try, so that a TypeError raised while
            # iterating a nested item propagates instead of being mistaken
            # for "not iterable" (which would append the partly flattened
            # item as a leaf).
            flatten(item, output)
    return output
def flatten_dict(dict_, output=None):
    """Copy items at all levels of a nested dictionary to the outer level.

    Note: If a key appears at multiple levels, the value from the outermost
    level will appear in the output dictionary. Among equally deep
    occurrences, the one visited last wins. Nested dicts are also kept as
    values under their own keys. A dict reachable more than once (including
    one that contains itself) is expanded only once, at its shallowest
    position.

    Arguments:
        dict_ (dict): the nested dictionary to flatten.
        output (dict, optional): dictionary to write into (a new one is
            created by default). Existing entries are overwritten by any
            key found in dict_.

    Returns output.
    """
    if output is None:
        output = {}
    # Gather the dicts breadth-first, one list per depth, so that precedence
    # is decided by depth rather than by recursion order.
    levels = []
    seen = {id(dict_)}
    current = [dict_]
    while current:
        levels.append(current)
        deeper = []
        for mapping in current:
            for value in mapping.values():
                if isinstance(value, dict) and id(value) not in seen:
                    seen.add(id(value))
                    deeper.append(value)
        current = deeper
    # Write the deepest level first so that shallower levels overwrite it.
    for level in reversed(levels):
        for mapping in level:
            output.update(mapping)
    return output
def luhn_check_digit(seq):
    """Use Luhn algorithm to get check digit for a sequence of str/int.

    Arguments:
        seq: an iterable of single digits, given as ints (0-9) and/or
            strings. String items that are not digits (e.g. the dashes in
            "7992-7398-71") are skipped.

    Returns the check digit (0-9) to append to seq.

    Raises:
        ValueError: if seq contains no digits, or an item that is not a
            single decimal digit (e.g. the int 12 or the string "12"),
            which would otherwise silently give a wrong check digit.
    """
    nums = []
    for item in seq:
        if isinstance(item, int):
            digit = item
        elif item.isdigit():
            digit = int(item)
        else:
            continue
        if not 0 <= digit <= 9:
            raise ValueError("{0!r} is not a single decimal digit.".format(item))
        nums.append(digit)
    if not nums:
        raise ValueError("No digits found in sequence.")
    # Double every second digit, starting with the rightmost. Subtracting 9
    # from a doubled value above 9 equals summing its two digits.
    nums[-1::-2] = [(2 * num) - (9 if num > 4 else 0) for num in nums[-1::-2]]
    # 9 == -1 (mod 10), so this is the digit that brings the total to a
    # multiple of 10.
    return (sum(nums) * 9) % 10
def check_luhn(string):
    """Validate a sequence of str/int whose last digit is a Luhn check digit.

    int and float items are truncated with int(). String items are kept
    only if str.isdigit() is true, so separators such as spaces or dashes
    are ignored. Returns True when the final digit equals the check digit
    that luhn_check_digit computes for the digits before it.

    Raises:
        ValueError: if no digits are found. luhn_check_digit may also raise
            ValueError, e.g. when only a single digit is present.
    """
    digits = []
    for char in string:
        if isinstance(char, (float, int)) or char.isdigit():
            digits.append(int(char))
    if not digits:
        raise ValueError
    payload, check_digit = digits[:-1], digits[-1]
    return check_digit == luhn_check_digit(payload)
def roman_to_arabic(roman):
    """Convert a string of Roman numerals (any case) to an integer.

    Scans right to left: a symbol at least as large as every symbol seen so
    far is added, and a smaller one is subtracted. The notation is tolerated
    loosely (e.g. "XIIX" gives 18), and an empty string gives 0.

    Raises:
        ValueError: if roman contains a character that is not a numeral.
    """
    values = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
    try:
        digits = [values[symbol.upper()] for symbol in roman]
    except KeyError:
        raise ValueError("Not a valid Roman numeral string.")
    total = 0
    largest = 0
    for value in reversed(digits):
        if value >= largest:
            largest = value
            total += value
        else:
            total -= value
    return total
def _gcd(first, second):
    """Return the non-negative greatest common divisor of two numbers (Euclid)."""
    while second:
        first, second = second, first % second
    return abs(first)


def coprime(values, pairwise=False):
    """Return whether values share no common divisor greater than 1.

    Arguments:
        values (iterable of int): the numbers to test.
        pairwise (bool, optional): if True, every pair of values must be
            coprime (vacuously True for fewer than two values). Otherwise
            the greatest common divisor of all the values together must be
            1 (defaults to False).

    For example, (6, 10, 15) is coprime as a set (gcd 1) but not pairwise,
    because 6 and 10 share the divisor 2.

    Raises:
        ValueError: if pairwise is False and values is empty.
    """
    if pairwise:
        return all(coprime((a, b), False) for a, b in combinations(values, 2))
    values = list(values)
    if not values:
        raise ValueError("coprime() requires at least one value")
    divisor = 0  # gcd(0, n) == |n|, so 0 is the identity of the fold
    for value in values:
        divisor = _gcd(divisor, value)
        if divisor == 1:
            return True  # once the running gcd is 1 it stays 1
    return divisor == 1
# Note: no automated test coverage for sanitised_input()
def _range_message(range_):
    """Return the re-prompt message describing the values allowed by range_."""
    if type(range_) == range:  # pylint: disable=unidiomatic-typecheck
        # A range excludes its stop value, so quote its first and last
        # members (smaller first) rather than .start and .stop.
        low, high = sorted((range_[0], range_[-1]))
        return "Input must be between {0} and {1}.".format(low, high)
    # Materialise the items so that unsliceable containers (sets, dict
    # views) can be described too.
    options = [str(option) for option in range_]
    if len(options) == 1:
        return "Input must be {0}.".format(options[0])
    return "Input must be {0}.".format(
        " or ".join((", ".join(options[:-1]), options[-1])))


def sanitised_input(prompt, type_=None, min_=None,
                    max_=None, len_=None, range_=None):
    """Take user input, re-prompting until it passes every check.

    Arguments:
        prompt (str): text displayed by input().
        type_ (callable, optional): converter applied to the raw string
            (e.g. int). If it raises ValueError the user is re-prompted.
        min_, max_ (optional): inclusive bounds on the (converted) value.
        len_ (int, optional): exact len() the (converted) value must have.
        range_ (sized container, optional): the value must be a member of
            it.

    Returns the first value that passes all checks.

    Raises:
        ValueError: if max_ < min_, or if range_ is empty (no input could
            ever be accepted, so the prompt would never end).
    """
    # pylint: disable=too-many-arguments, too-many-branches
    if min_ is not None and max_ is not None and max_ < min_:
        raise ValueError("min_ must be less than or equal to max_.")
    if range_ is not None and len(range_) == 0:
        raise ValueError("range_ must contain at least one value.")
    while True:
        input_ = input(prompt)
        if type_ is not None:
            try:
                input_ = type_(input_)
            except ValueError:
                print("Input type must be {0}.".format(type_.__name__))
                continue
        if max_ is not None and input_ > max_:
            print("Input must be less than or equal to {0}.".format(max_))
        elif min_ is not None and input_ < min_:
            print("Input must be greater than or equal to {0}.".format(min_))
        elif len_ is not None and len(input_) != len_:
            print("Input must have length {0}.".format(len_))
        elif range_ is not None and input_ not in range_:
            print(_range_message(range_))
        else:
            return input_
if __name__ == "__main__":
    # Self-test: run each helper against a table of expected outcomes using
    # test_function from the companion testing module.
    # pylint: disable=invalid-name
    from testing import test_function, print_outcome
    success = True
    func = mean
    TESTS = {1: (ValueError, ([1], "ValueError")),
             2: (2, ([1, 2, 3], "arithmetic")),
             3: (12.0 / 7, ([1, 2, 4], "harmonic")),
             4: (4, ([2, 8], "geometric")),
             5: ((14.0 / 3) ** 0.5, ([1, 2, 3], "quadratic")),
             6: (TypeError, ([1, 2, "a"], "TypeError"))}
    success = test_function(TESTS, func) and success
    func = product
    TESTS = {1: (6, ([1, 2, 3],)),
             2: ("aa", ([1, 2, "a"],)),
             3: (TypeError, ([1, "a", "b"],)),
             4: (1, ([],))}
    success = test_function(TESTS, func) and success
    func = factorial
    TESTS = {1: (362880, (9,)),
             2: (TypeError, ("a",))}
    success = test_function(TESTS, func) and success
    func = std_dev
    TESTS = {1: (sqrt(2.0 / 3), ([1, 2, 3], ))}
    success = test_function(TESTS, func) and success
    # Rebind func (not a stray name) so these cases exercise
    # evaluate_polynomial rather than re-running std_dev.
    func = evaluate_polynomial
    TESTS = {1: (0, ([], 1)),
             2: (1, ([1], 1)),
             3: (3, ([1, 2], 1)),
             4: (6, ([1, 2, 3], 1)),
             5: (2, ([1, -2, 3], 1))}
    success = test_function(TESTS, func) and success
    func = percentile
    TEST_FUNC = lambda xn: sorted(xn, reverse=True)
    TESTS = {1: (20, ([15, 20, 35, 40, 50], 0.3)),
             2: (20, ([20, 40, 15, 50, 35], 0.3)),
             3: (40, ([20, 40, 15, 50, 35], 0.3, TEST_FUNC))}
    success = test_function(TESTS, func) and success
    func = is_palindrome
    TESTS = {1: (True, ("anna",)),
             2: (True, ("bob",)),
             3: (False, ("colin",)),
             4: (True, ("madam i'm adam",)),
             5: (False, ("madam i'm eve",)),
             6: (False, ("madam i'm adam", (" "))),
             7: (True, ("madam i'm adam", (" '"))),
             8: (True, ("a man, a plan, a canal: Panama",))}
    success = test_function(TESTS, func) and success
    func = flatten
    TESTS = {1: ([1, 2, 3, 4, 5, 6, 7, 8], ([1, [2, [3, [4, 5], 6], 7], 8],)),
             2: ([1, 'foo', 'bar', 2, 3], ([1, "foo", ["bar", 2], 3],)),
             3: ([1, 2, 3, 4], ([1, (2, 3), {4: 5}],)),
             4: (['hello'], ('hello',)),
             5: (TypeError, (1,))}
    success = test_function(TESTS, func) and success
    func = flatten_dict
    TESTS = {1: ({}, ({},)),
             2: ({'a': {'b':2, 'c':3}, 'd':4, 'b':2, 'c':3},
                 ({'a': {'b':2, 'c':3}, 'd':4},)),
             3: ({'a': {'b':2, 'c':3}, 'b':4, 'c':3},
                 ({'a': {'b':2, 'c':3}, 'b':4},))}
    success = test_function(TESTS, func) and success
    func = luhn_check_digit
    TESTS = {1: (ValueError, ("hello",)),
             2: (3, ("7992739871",))}
    success = test_function(TESTS, func) and success
    func = check_luhn
    TESTS = {1: (ValueError, ("hello",)),
             2: (True, ("79927398713",)),
             3: (False, ("79927398716",))}
    success = test_function(TESTS, func) and success
    func = roman_to_arabic
    TESTS = {1: (18, ("XIIX",)),
             2: (1954, ("MCMLIV",)),
             3: (18, ("xviii",)),
             4: (0, ("",)),
             5: (ValueError, ("foo",))}
    success = test_function(TESTS, func) and success
    func = coprime
    TESTS = {1: (True, ((4, 9),)),
             2: (True, ((4, 9), True)),
             3: (False, ((6, 10),)),
             4: (True, ((6, 10, 15),)),
             5: (False, ((6, 10, 15), True))}
    success = test_function(TESTS, func) and success
    print_outcome(success)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment