Last active
June 20, 2018 17:45
-
-
Save Daenyth/e4b37e31c20994b4090178d1f98742f3 to your computer and use it in GitHub Desktop.
functional.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Functional programming utilities for python | |
""" | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
# pylint: disable=redefined-builtin | |
from collections import defaultdict | |
from copy import copy | |
from builtins import filter, map | |
from future.utils import PY2, iteritems | |
from typing import Callable, Dict, Generic, Iterable, Iterator, List, TypeVar, Set, Optional # pylint: disable=unused-import | |
if PY2: | |
from itertools import ifilterfalse as filterfalse | |
else: | |
from itertools import filterfalse # pylint: disable=no-name-in-module | |
from itertools import tee | |
from functools import reduce # pylint: disable=redefined-builtin | |
import operator | |
A = TypeVar('A') | |
B = TypeVar('B') | |
C = TypeVar('C') | |
def distinct_by(key, iterable):
    # type: (Callable[[A], B], Iterable[A]) -> Iterator[A]
    """Lazily yield the items of ``iterable`` whose key has not been seen before.

    :param key: A function mapping an item to the (hashable) value that identifies it
    :param iterable: The items to de-duplicate
    :return: An iterator over the first item seen for each distinct key, in input order
    """
    seen = set()  # type: Set[B]
    for item in iterable:
        # Named ``item_key`` rather than ``id`` so the builtin is not shadowed.
        item_key = key(item)
        if item_key not in seen:
            seen.add(item_key)
            yield item
def group_by(func, iterable):
    # type: (Callable[[A], B], Iterable[A]) -> Dict[B, List[A]]
    """Bucket the items of an iterable under the key computed for each one.

    :param func: Called with one item; its return value is the key the item is filed under
    :param iterable: The items to bucket
    :return: A ``defaultdict(list)`` mapping each key to the items that produced it, in input order.
    """
    buckets = defaultdict(list)  # type: Dict[B, List[A]]
    for element in iterable:
        bucket = buckets[func(element)]
        bucket.append(element)
    return buckets
def group_map(key, mapper, iterable):
    # type: (Callable[[A], B], Callable[[A], C], Iterable[A]) -> Dict[B, List[C]]
    """Bucket an iterable by a key, storing a transformed value for each item instead of the item itself.

    :param key: Called with one item; returns the key the item is filed under
    :param mapper: Called with one item; returns the value stored for it
    :param iterable: The items to process
    :return: A ``defaultdict(list)`` mapping each key to the mapped values of its items, in input order
    """
    grouped = defaultdict(list)  # type: Dict[B, List[C]]
    for element in iterable:
        bucket = grouped[key(element)]
        bucket.append(mapper(element))
    return grouped
def filter_keys(predicate, mapping):
    """
    Build a new dict holding only the entries of ``mapping`` whose key satisfies ``predicate``.

    :param predicate: Called with each key; the entry is kept when the result is truthy
    :param mapping: The mapping to select entries from; it is not modified
    :return: A new dict of the retained key/value pairs
    """
    return {key: mapping[key] for key in mapping if predicate(key)}
def filter_dict(predicate, mapping):
    """
    Build a new dict holding only the entries of ``mapping`` for which ``predicate(key, value)`` is truthy.

    :param predicate: Called with each key and its value
    :param mapping: The mapping to select entries from; it is not modified
    :return: A new dict of the retained key/value pairs
    """
    kept = {}
    for key in mapping:
        value = mapping[key]
        if predicate(key, value):
            kept[key] = value
    return kept
def find(predicate, iterable):
    """Scan ``iterable`` in order and return the first element for which ``predicate`` is truthy.

    Returns None when no element matches, including when ``iterable`` is empty.
    """
    for candidate in iterable:
        if not predicate(candidate):
            continue
        return candidate
    return None
# pylint: disable=invalid-name | |
def partition(predicate, iterable):
    """
    Split ``iterable`` into the entries that satisfy ``predicate`` and the entries that do not.

    # partition(is_odd, range(10)) --> 1 3 5 7 9 and 0 2 4 6 8

    :return: A pair ``(matching, non_matching)`` of iterators. Each one filters its own
        ``tee`` copy of the input, so ``predicate`` is evaluated once per element per iterator.
    """
    # Adapted from itertools recipes page
    for_true, for_false = tee(iterable)
    matching = filter(predicate, for_true)
    non_matching = filterfalse(predicate, for_false)
    return matching, non_matching
def flatten(iterable_of_iterables):
    # type: (Iterable[Iterable[A]]) -> Iterator[A]
    """
    Yield every item of every inner iterable, in order, removing exactly one level of nesting.
    :return: A generator
    """
    for inner in iterable_of_iterables:
        for element in inner:
            yield element
def concatmap(func, iterable):
    # type: (Callable[[A], Iterable[B]], Iterable[A]) -> Iterator[B]
    """
    Apply ``func`` (which returns an iterable) to each item and yield the contents of every result in order,
    so the output is one level flatter than a plain map would be.
    aka flatMap specialized to iterables
    :return: An iterator
    """
    for element in iterable:
        for produced in func(element):
            yield produced
def unsafe_head(iterable, default=None):
    # type: (Iterable[A], Optional[A]) -> Optional[A]
    """Fetch the first element of ``iterable``, or ``default`` when it has no elements.
    NB: an iterator passed in is advanced by one element.
    """
    return next(iter(iterable), default)
class Monoid(Generic[A]):
    """
    An associative binary operation paired with its identity element.

    ``op`` combines two values; ``zero`` is a no-argument callable producing the identity value.
    See https://en.wikipedia.org/wiki/Monoid
    """

    def __init__(self, op, zero):
        # type: (Callable[[A, A], A], Callable[[], A]) -> None
        self.zero = zero
        self.op = op

    def monoid_sum(self, iterable):
        # type: (Iterable[A]) -> A
        """Delegates to the free monoid_sum function; subclasses may override this with something more efficient"""
        return monoid_sum(self, iterable)

    def foldmap(self, func, iterable):
        # type: (Callable[[B], A], Iterable[B]) -> A
        """Maps func over iterable and combines the results with self.monoid_sum; may be overridden for efficiency"""
        mapped = map(func, iterable)
        return self.monoid_sum(mapped)
# Ready-made monoid instances for builtin collections.
# Sets combine with set.union; the identity is an empty set.
SetMonoid = Monoid(op=set.union, zero=set)  # type: Monoid[set]
# Lists combine with operator.add (concatenation); the identity is an empty list.
ListMonoid = Monoid(op=operator.add, zero=list)  # type: Monoid[list]
class DictMergeMonoid(Monoid):
    """Monoid over dictionaries: keys are unioned and values under a shared key are combined with a value monoid"""

    def __init__(self, value_monoid):  # pylint: disable=super-init-not-called
        # ``op`` and ``zero`` are defined on the class itself, so only the value monoid is stored.
        self.value_monoid = value_monoid

    @staticmethod
    def zero():
        """Return a new empty dict, the identity element for merging"""
        return dict()

    def op(self, x, y):
        """Return a new dict holding ``x`` merged with ``y``; ``x`` is shallow-copied first so it is not mutated"""
        merged = copy(x)
        self._unsafe_op(merged, y)
        return merged

    def _unsafe_op(self, target, other):
        """Merge ``other`` into ``target`` in place (``target`` is mutated)"""
        for key, value in other.items():
            if key in target:
                # Key present on both sides: combine the two values with the value monoid.
                target[key] = self.value_monoid.op(target[key], value)
            else:
                target[key] = value

    def monoid_sum(self, iterable):
        """Merge every dict in ``iterable`` into one result, copying only the first dict"""
        remaining = iter(iterable)
        nothing = object()  # sentinel distinguishing "no first element" from any real value
        first = next(remaining, nothing)
        if first is nothing:
            return self.zero()
        accumulated = copy(first)
        for other in remaining:
            self._unsafe_op(accumulated, other)
        return accumulated
def foldmap(func, monoid, iterable):
    # type: (Callable[[A], B], Monoid[B], Iterable[A]) -> B
    """Applies func to each item of iterable and combines the results using monoid (via monoid_sum)"""
    mapped = map(func, iterable)
    return monoid_sum(monoid, mapped)
def monoid_sum(monoid, iterable):
    # type: (Monoid[A], Iterable[A]) -> A
    """
    Combines every item of an iterable with the monoid's operation, left to right,
    starting from the monoid's zero (which is therefore the result for an empty iterable).
    Equivalent to foldmap(lambda x: x, monoid, iterable)
    Sometimes called `suml`
    """
    combine = monoid.op
    total = monoid.zero()
    for item in iterable:
        total = combine(total, item)
    return total
def tupled(func):
    """Adapt a function of N positional arguments into a function taking a single N-tuple"""
    def call_with_unpacked(args):
        return func(*args)
    return call_with_unpacked
def untupled(func):
    """Adapt a function taking a single N-tuple into a function of N positional arguments"""
    def call_with_packed(*args):
        return func(args)
    return call_with_packed
def kwstarred(func):
    """Adapt a function of keyword arguments into a function taking one dict of those arguments"""
    def call_with_kwargs(d):
        return func(**d)
    return call_with_kwargs
def const(value):
    # type: (A) -> Callable[..., A]
    """Build a function that ignores whatever arguments it is called with and returns ``value``."""
    def constant(*_args, **_kwargs):
        return value
    return constant
def identity(obj):
    # type: (A) -> A
    """Hand back the argument unchanged (the very same object)."""
    return obj
def compose(f, g):
    # type: (Callable[[B], C], Callable[[A], B]) -> Callable[[A], C]
    """Return the function ``x -> f(g(x))``: ``g`` is applied first, then ``f``."""
    def composed(x):
        return f(g(x))
    return composed
def compose_all(*fns):
    # type: (*Callable[[A], A]) -> Callable[[A], A]
    """Compose multiple callables taking one argument into a single function.

    Functions are applied right to left, as in mathematical composition:
    ``compose_all(f, g, h)(x) == f(g(h(x)))``. Called with no arguments, it
    returns ``identity``.

    :param fns: One-argument callables to chain
    :return: A one-argument callable applying every function in ``fns``
    """
    # Left fold seeded with identity: compose(compose(identity, f), g) is x -> f(g(x)).
    return reduce(compose, fns, identity)  # type: ignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
import operator as op | |
import unittest | |
from collections import namedtuple | |
from itertools import tee | |
import hypothesis.strategies as st | |
from builtins import range | |
from future.utils import lmap, lrange, text_type | |
from hypothesis import HealthCheck, given, settings | |
from functional import ( | |
DictMergeMonoid, | |
ListMonoid, | |
Monoid, | |
SetMonoid, | |
filter_dict, | |
filter_keys, | |
find, | |
foldmap, | |
group_by, | |
kwstarred, | |
monoid_sum, | |
partition, | |
tupled, | |
untupled, | |
unsafe_head, | |
) | |
# pylint: disable=invalid-name | |
class FunctionalTest(unittest.TestCase):
    """Unit tests for find, group_by, partition and foldmap."""

    def setUp(self):
        # Three records sharing the same two keys; only the 'bar' record has foundField set to True.
        keys = ['myField', 'foundField']
        rows = (['foo', False], ['bar', True], ['baz', False])
        self.dict_list = [dict(zip(keys, row)) for row in rows]

    def test_find__item_found(self):
        match = find(lambda x: x['myField'] == 'bar', self.dict_list)
        self.assertEqual(match['myField'], 'bar')
        self.assertEqual(match['foundField'], True)

    def test_find__item_not_found(self):
        match = find(lambda x: x['myField'] == 'NOT_IN_LIST', self.dict_list)
        self.assertIsNone(match)

    def test_group_by(self):
        # self.fail as the key function: the test fails if it is ever called for empty input.
        self.assertEqual({}, group_by(self.fail, []))
        # A constant key puts everything in one group; an identity key gives singleton groups.
        self.assertEqual({True: lrange(10)}, group_by(lambda x: True, list(range(10))))
        self.assertEqual({i: [i] for i in lrange(10)}, group_by(lambda x: x, lrange(10)))

        def even(x):
            return x % 2 == 0

        odds_evens = group_by(even, lrange(10))
        self.assertTrue(all(even(x) for x in odds_evens[True]))
        self.assertFalse(any(even(x) for x in odds_evens[False]))

        # Group points by x; x == 0 contributes no points, so it has no group.
        Point = namedtuple('Point', 'x y')  # pylint: disable=invalid-name
        points = [Point(x, y) for x in lrange(10) for y in lrange(x)]
        expected = {x: [Point(x, y) for y in lrange(x)] for x in lrange(1, 10)}
        self.assertEqual(expected, group_by(lambda p: p.x, points))

    def test_partition(self):
        def is_odd(x):
            return x % 2 != 0

        odds, evens = partition(is_odd, range(500))
        self.assertTrue(all(is_odd(n) for n in odds))
        self.assertFalse(any(is_odd(n) for n in evens))

    def test_partition__empty(self):
        matching, non_matching = partition(lambda x: True, [])
        self.assertEqual([[], []], [list(matching), list(non_matching)])

    @given(st.lists(st.text()))
    def test_foldmap(self, texts):
        # Sum of lengths under integer addition.
        int_add_monoid = Monoid(op=int.__add__, zero=lambda: 0)
        expected_sum = sum(len(t) for t in texts)
        self.assertEqual(expected_sum, foldmap(len, int_add_monoid, texts))

        # Product of lengths under integer multiplication.
        int_mul_monoid = Monoid(op=op.mul, zero=lambda: 1)
        expected_product = 1
        for text in texts:
            expected_product *= len(text)
        self.assertEqual(expected_product, foldmap(len, int_mul_monoid, texts))

        # Concatenation of the texts themselves under string addition.
        str_cat_monoid = Monoid(op=text_type.__add__, zero=lambda: u"")
        self.assertEqual("".join(texts), foldmap(lambda x: x, str_cat_monoid, texts))
@given(st.dictionaries(st.integers(), st.none()))
def test_filter_keys(mapping):
    """The result is a new object, every key in it passes the predicate, and no rejected key survives."""
    def is_even(k):
        return k % 2 == 0

    filtered = filter_keys(is_even, mapping)
    assert filtered is not mapping
    assert all(is_even(k) for k in filtered)
    assert not any(k in filtered for k in mapping if not is_even(k))
@given(st.dictionaries(st.integers(), st.integers(min_value=1)))
def test_filter_dict(mapping):
    """The result is a new object, every item in it passes the predicate, and no rejected key survives."""
    def divides(k, v):
        # Values are generated with min_value=1, so the modulo never divides by zero.
        return k % v == 0

    filtered = filter_dict(divides, mapping)
    assert filtered is not mapping
    assert all(divides(k, v) for k, v in filtered.items())
    assert not any(k in filtered for k, v in mapping.items() if not divides(k, v))
@given(st.lists(st.integers(), max_size=2) | st.lists(st.integers(), max_size=2).map(iter), st.none() | st.integers())
def test_head(iterable, default):
    """unsafe_head gives the first element, or the default for empty input, for both lists and iterators."""
    # tee lets one copy be listed for reference while the other is handed to unsafe_head.
    under_test, reference = tee(iterable)
    everything = list(reference)
    the_head = unsafe_head(under_test, default)
    if not everything:
        assert the_head == default
    else:
        assert [the_head] + everything[1:] == everything  # head + rest == all
class MonoidLawTest(object):
    """Mixin giving monoid law tests.

    Subclasses set ``monoid`` to the instance under test and ``strategy`` to a
    hypothesis strategy generating values for it.
    """
    monoid = None
    strategy = None

    def test_identity(self):
        monoid = self.monoid

        @given(self.strategy)
        def check_identity(value):
            assert monoid.op(value, monoid.zero()) == value  # Right identity
            assert monoid.op(monoid.zero(), value) == value  # Left identity

        check_identity()  # pylint: disable=no-value-for-parameter

    def test_associativity(self):  # pylint: disable=no-value-for-parameter
        monoid = self.monoid

        @given(self.strategy, self.strategy, self.strategy)
        @settings(suppress_health_check=[HealthCheck.too_slow])
        def check_associativity(first, second, third):
            left_grouped = monoid.op(monoid.op(first, second), third)
            right_grouped = monoid.op(first, monoid.op(second, third))
            assert left_grouped == right_grouped

        check_associativity()  # pylint: disable=no-value-for-parameter

    def test_monoid_sum(self):
        """Check that the monoid's own monoid_sum method agrees with the free monoid_sum function"""
        monoid = self.monoid

        @given(st.lists(self.strategy))
        @settings(suppress_health_check=[HealthCheck.too_slow])
        def check_msum(values):
            reference = monoid_sum(monoid, values)
            assert reference == monoid.monoid_sum(values)

        check_msum()  # pylint: disable=no-value-for-parameter
class TestSetMonoid(MonoidLawTest, unittest.TestCase):
    """Runs the MonoidLawTest checks against SetMonoid using sets of integers."""
    monoid = SetMonoid
    strategy = st.sets(st.integers())
class TestDictMergeMonoid(MonoidLawTest, unittest.TestCase):
    """Runs the MonoidLawTest checks against DictMergeMonoid using dicts from integers to sets of integers."""
    monoid = DictMergeMonoid(SetMonoid)  # SetMonoid for convenience; it could be any valid monoid
    strategy = st.dictionaries(st.integers(), st.sets(st.integers()))
class TestListMonoid(MonoidLawTest, unittest.TestCase):
    """Runs the MonoidLawTest checks against ListMonoid using lists of integers."""
    monoid = ListMonoid
    strategy = st.lists(st.integers())
def test_tupled():
    """tupled(f) applied to a pair matches f applied to the pair's two elements."""
    pair = (1, 2)
    add_pair = tupled(op.add)
    assert add_pair(pair) == op.add(*pair)
def test_untupled():
    """untupled(f) called with separate arguments matches f called with them packed in a tuple."""
    def first_two(t):
        return (t[0], t[1])

    spread = untupled(first_two)
    assert spread(1, 2) == first_two((1, 2))
def test_tuple_inverts_untupled():
    """Wrapping with untupled and then tupled gives back a function equivalent to the original."""
    def first_two(t):
        return (t[0], t[1])

    round_tripped = tupled(untupled(first_two))
    assert round_tripped((1, 2)) == first_two((1, 2))
def test_untupled_inverts_tupled():
    """Wrapping with tupled and then untupled gives back a function equivalent to the original."""
    round_tripped = untupled(tupled(op.add))
    assert round_tripped(1, 2) == op.add(1, 2)
def test_kwstarred():
    """kwstarred(f) called with a dict matches f called with that dict's entries as keyword arguments."""
    def as_pair(a, b):
        return a, b

    from_dict = kwstarred(as_pair)
    assert from_dict({'a': 1, 'b': 2}) == as_pair(a=1, b=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment