Skip to content

Instantly share code, notes, and snippets.

@Daenyth
Last active June 20, 2018 17:45
Show Gist options
  • Save Daenyth/e4b37e31c20994b4090178d1f98742f3 to your computer and use it in GitHub Desktop.
Save Daenyth/e4b37e31c20994b4090178d1f98742f3 to your computer and use it in GitHub Desktop.
functional.py
"""
Functional programming utilities for Python: iterable helpers, monoids, and function combinators.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
# pylint: disable=redefined-builtin
from collections import defaultdict
from copy import copy
from builtins import filter, map
from future.utils import PY2, iteritems
from typing import Callable, Dict, Generic, Iterable, Iterator, List, TypeVar, Set, Optional # pylint: disable=unused-import
if PY2:
from itertools import ifilterfalse as filterfalse
else:
from itertools import filterfalse # pylint: disable=no-name-in-module
from itertools import tee
from functools import reduce # pylint: disable=redefined-builtin
import operator
# Type variables referenced by the `# type:` comments and the generic
# Monoid class in this module.
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
def distinct_by(key, iterable):
    # type: (Callable[[A], B], Iterable[A]) -> Iterator[A]
    """Lazily yield the items of `iterable`, skipping any whose key was already seen.

    The first item producing a given ``key(item)`` is kept; later items with an
    equal key are dropped. Input order is preserved.

    :param key: Function mapping an item to its identity; results must be hashable
        because they are stored in a set.
    :param iterable: Items to de-duplicate.
    :return: A generator over the distinct items.
    """
    seen = set()  # type: Set[B]
    for item in iterable:
        # Named `item_key` rather than `id` to avoid shadowing the builtin.
        item_key = key(item)
        if item_key not in seen:
            seen.add(item_key)
            yield item
def group_by(func, iterable):
    # type: (Callable[[A], B], Iterable[A]) -> Dict[B, List[A]]
    """Bucket the items of `iterable` by the key ``func(item)``.

    :param func: Called once per item; its result is the bucket key.
    :param iterable: The items to bucket.
    :return: A ``defaultdict(list)`` mapping each key to the items that produced it,
        in the order they were encountered.
    """
    buckets = defaultdict(list)  # type: Dict[B, List[A]]
    for element in iterable:
        bucket_key = func(element)
        buckets[bucket_key].append(element)
    return buckets
def group_map(key, mapper, iterable):
    # type: (Callable[[A], B], Callable[[A], C], Iterable[A]) -> Dict[B, List[C]]
    """Bucket `iterable` by ``key(item)``, storing ``mapper(item)`` rather than the item.

    :param key: Computes the bucket key of each item (called before `mapper`).
    :param mapper: Computes the value stored for each item.
    :param iterable: The items to process.
    :return: A ``defaultdict(list)`` from each key to the mapped values, in encounter order.
    """
    grouped = defaultdict(list)  # type: Dict[B, List[C]]
    for element in iterable:
        bucket_key = key(element)
        grouped[bucket_key].append(mapper(element))
    return grouped
def filter_keys(predicate, mapping):
    """
    Return a new dict containing only the entries of `mapping` whose key satisfies `predicate`.

    Values are carried over unchanged and `mapping` itself is left untouched.
    """
    kept = {}
    for k, v in iteritems(mapping):
        if predicate(k):
            kept[k] = v
    return kept
def filter_dict(predicate, mapping):
    """
    Return a new dict containing only the entries of `mapping` for which `predicate(key, value)` holds.

    `mapping` itself is left untouched.
    """
    kept = {}
    for k, v in iteritems(mapping):
        if predicate(k, v):
            kept[k] = v
    return kept
def find(predicate, iterable):
    """Return the first item of `iterable` for which `predicate(item)` is truthy.

    Iteration stops at the first match. Returns None when nothing matches,
    including when `iterable` is empty.
    """
    for candidate in iterable:
        if not predicate(candidate):
            continue
        return candidate
    return None
# pylint: disable=invalid-name
def partition(predicate, iterable):
    """
    Split `iterable` into two lazy iterators: items where `predicate` is truthy, and the rest.

    Example: partition(is_odd, range(10)) --> (1 3 5 7 9), (0 2 4 6 8)

    The input is shared between both results via itertools.tee, so items read
    by one side but not yet by the other are buffered in memory.
    """
    # Based on the `partition` recipe from the itertools documentation.
    for_true, for_false = tee(iterable)
    matching = filter(predicate, for_true)
    rest = filterfalse(predicate, for_false)
    return matching, rest
def flatten(iterable_of_iterables):
    """
    Lazily chain the inner iterables of `iterable_of_iterables` into one stream.

    Only one level of nesting is removed; inner elements are yielded as-is.

    :return: A generator
    """
    for inner in iterable_of_iterables:
        for element in inner:
            yield element
def concatmap(func, iterable):
    # type: (Callable[[A], Iterable[B]], Iterable[A]) -> Iterator[B]
    """
    Apply `func` to each item and lazily yield every element of each result, in order.

    This is flatMap for iterables: the same as flattening ``map(func, iterable)``
    by one level.

    :return: An iterator
    """
    for element in iterable:
        expanded = func(element)
        for produced in expanded:
            yield produced
def unsafe_head(iterable, default=None):
    # type: (Iterable[A], Optional[A]) -> Optional[A]
    """Return the first element of `iterable`, or `default` when it is empty.

    Note: if `iterable` is itself an iterator, its first element is consumed.
    """
    return next(iter(iterable), default)
class Monoid(Generic[A]):
    """
    An associative binary operation paired with a function producing its identity element.

    Instances are expected (not enforced) to satisfy:
      * op(op(x, y), z) == op(x, op(y, z))
      * op(zero(), x) == x == op(x, zero())

    See https://en.wikipedia.org/wiki/Monoid
    """

    def __init__(self, op, zero):
        # type: (Callable[[A, A], A], Callable[[], A]) -> None
        self.op = op  # the associative combine function
        self.zero = zero  # zero-argument callable returning the identity element

    def monoid_sum(self, iterable):
        # type: (Iterable[A]) -> A
        """Combine every item of `iterable` with this monoid.

        Delegates to the module-level monoid_sum; subclasses may override this
        with a faster equivalent.
        """
        return monoid_sum(self, iterable)

    def foldmap(self, func, iterable):
        # type: (Callable[[B], A], Iterable[B]) -> A
        """Map `func` over `iterable`, then combine the results with self.monoid_sum.

        Subclasses may override this with a faster equivalent.
        """
        mapped = map(func, iterable)
        return self.monoid_sum(mapped)
# Sets combined by union; the identity element is an empty set.
SetMonoid = Monoid(zero=set, op=set.union)  # type: Monoid[set]
# Lists combined by concatenation; the identity element is an empty list.
ListMonoid = Monoid(zero=list, op=operator.add)  # type: Monoid[list]
class DictMergeMonoid(Monoid):
    """Monoid over dicts that merges them key by key.

    Keys present in only one operand are carried over as-is; values for keys
    present in both are combined with ``value_monoid.op``.
    """

    def __init__(self, value_monoid):  # pylint: disable=super-init-not-called
        # op and zero are provided as methods below, so Monoid.__init__ is
        # intentionally not called.
        self.value_monoid = value_monoid

    @staticmethod
    def zero():
        """Return the identity element: a new empty dict."""
        return {}

    def op(self, x, y):
        """Return a new dict merging `y` into a shallow copy of `x`.

        Values are not copied, so the result may share value objects with the inputs.
        """
        merged = copy(x)
        self._unsafe_op(merged, y)
        return merged

    def _unsafe_op(self, target, other):
        """Merge `other` into `target` in place (mutates `target`)."""
        for key, value in other.items():
            target[key] = self.value_monoid.op(target[key], value) if key in target else value

    def monoid_sum(self, iterable):
        """Merge every dict in `iterable`, returning zero() when it is empty.

        Only the first dict is copied; the rest are merged into that copy in
        place, instead of copying once per element as repeated op() calls would.
        """
        remaining = iter(iterable)
        for first in remaining:
            merged = copy(first)
            # The inner loop drains the same iterator, so this runs at most once.
            for other in remaining:
                self._unsafe_op(merged, other)
            return merged
        return self.zero()
def foldmap(func, monoid, iterable):
    # type: (Callable[[A], B], Monoid[B], Iterable[A]) -> B
    """Apply `func` to every item of `iterable` and combine the results with `monoid`.

    This goes through the module-level monoid_sum, so it folds with
    ``monoid.op`` starting from ``monoid.zero()``. It does not use an
    overridden ``Monoid.monoid_sum``; call ``monoid.foldmap(func, iterable)``
    to get a subclass's specialised implementation.
    """
    mapped = map(func, iterable)
    return monoid_sum(monoid, mapped)
def monoid_sum(monoid, iterable):
    # type: (Monoid[A], Iterable[A]) -> A
    """
    Combine all items of `iterable` left to right with ``monoid.op``, starting from ``monoid.zero()``.

    Returns ``monoid.zero()`` for an empty iterable.
    Equivalent to foldmap(lambda x: x, monoid, iterable); sometimes called `suml`.
    """
    combine = monoid.op
    total = monoid.zero()
    for item in iterable:
        total = combine(total, item)
    return total
def tupled(func):
    """Adapt `func`, which takes N positional arguments, to take a single N-tuple.

    Example: tupled(operator.add)((1, 2)) == 3
    """
    def call_with_unpacked(args):
        return func(*args)
    return call_with_unpacked
def untupled(func):
    """Adapt `func`, which takes one tuple, to take that tuple's items as separate positional arguments."""
    def call_with_packed(*args):
        return func(args)
    return call_with_packed
def kwstarred(func):
    """Adapt `func`, which takes keyword arguments, to take a single mapping of them."""
    def call_with_kwargs(d):
        return func(**d)
    return call_with_kwargs
def const(value):
    # type: (A) -> Callable[..., A]
    """Return a function that accepts and ignores any arguments and always returns `value`."""
    def constant(*_args, **_kwargs):
        return value
    return constant
def identity(obj):
    # type: (A) -> A
    """Hand back `obj` itself (the same object, not a copy)."""
    result = obj
    return result
def compose(f, g):
    # type: (Callable[[B], C], Callable[[A], B]) -> Callable[[A], C]
    """Return the composition "f after g": a one-argument function computing ``f(g(x))``."""
    def composed(x):
        return f(g(x))
    return composed
def compose_all(*fns):
    # type: (*Callable[[A], A]) -> Callable[[A], A]
    """Compose any number of one-argument callables into a single function.

    Functions are applied right to left, as in mathematical composition:
    ``compose_all(f, g, h)(x) == f(g(h(x)))``. With no functions, the result
    returns its argument unchanged.

    The callables are applied in a loop rather than through nested closures,
    so call depth stays constant however many functions are composed.
    """
    # Reverse once up front. `fns` is an immutable tuple, so the composed
    # function cannot be affected by anything the caller does afterwards.
    pipeline = fns[::-1]

    def composed(x):
        for fn in pipeline:
            x = fn(x)
        return x

    return composed
# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals
import operator as op
import unittest
from collections import namedtuple
from itertools import tee
import hypothesis.strategies as st
from builtins import range
from future.utils import lmap, lrange, text_type
from hypothesis import HealthCheck, given, settings
from functional import (
DictMergeMonoid,
ListMonoid,
Monoid,
SetMonoid,
filter_dict,
filter_keys,
find,
foldmap,
group_by,
kwstarred,
monoid_sum,
partition,
tupled,
untupled,
unsafe_head,
)
# pylint: disable=invalid-name
class FunctionalTest(unittest.TestCase):
    """Example-based and property-based tests for the helpers in `functional`."""

    def setUp(self):
        keys = ['myField', 'foundField']
        rows = [
            ['foo', False],
            ['bar', True],
            ['baz', False],
        ]
        # One dict per row, e.g. {'myField': 'foo', 'foundField': False}.
        self.dict_list = [dict(zip(keys, row)) for row in rows]

    def test_find__item_found(self):
        result = find(lambda x: x['myField'] == 'bar', self.dict_list)
        self.assertEqual(result['myField'], 'bar')
        self.assertEqual(result['foundField'], True)

    def test_find__item_not_found(self):
        result = find(lambda x: x['myField'] == 'NOT_IN_LIST', self.dict_list)
        self.assertIsNone(result)

    def test_group_by(self):
        # self.fail raises if called, proving the key function is never invoked on empty input.
        self.assertEqual({}, group_by(self.fail, []))
        self.assertEqual({True: lrange(10)}, group_by(lambda x: True, list(range(10))))
        self.assertEqual({i: [i] for i in lrange(10)}, group_by(lambda x: x, lrange(10)))
        even = lambda x: x % 2 == 0
        odds_evens = group_by(even, lrange(10))
        self.assertTrue(all(even(x) for x in odds_evens[True]))
        self.assertFalse(any(even(x) for x in odds_evens[False]))
        Point = namedtuple('Point', 'x y')  # pylint: disable=invalid-name
        points = [Point(x, y) for x in lrange(10) for y in lrange(x)]
        # x == 0 yields no points, so no 0 key is expected.
        expected = {x: [Point(x, y) for y in lrange(x)] for x in lrange(10) if x != 0}
        self.assertEqual(expected, group_by(lambda p: p.x, points))

    def test_partition(self):
        is_odd = lambda x: x % 2 != 0
        odds, evens = partition(is_odd, range(500))
        odds, evens = list(odds), list(evens)
        self.assertTrue(all(map(is_odd, odds)))
        self.assertFalse(any(map(is_odd, evens)))
        # Nothing may be lost or duplicated, and each side keeps input order.
        self.assertEqual(lrange(1, 500, 2), odds)
        self.assertEqual(lrange(0, 500, 2), evens)

    def test_partition__empty(self):
        p = lambda x: True
        result = lmap(list, partition(p, []))
        self.assertEqual([[], []], result)

    @given(st.lists(st.text()))
    def test_foldmap(self, texts):
        IntAddMonoid = Monoid(op=int.__add__, zero=lambda: 0)
        self.assertEqual(sum(len(t) for t in texts), foldmap(len, IntAddMonoid, texts))

        IntMulMonoid = Monoid(op=op.mul, zero=lambda: 1)
        tally = 1
        for t in texts:
            tally *= len(t)
        self.assertEqual(tally, foldmap(len, IntMulMonoid, texts))

        # String concatenation is non-commutative, so this also checks fold order.
        StrCatMonoid = Monoid(op=text_type.__add__, zero=lambda: u"")
        self.assertEqual("".join(texts), foldmap(lambda x: x, StrCatMonoid, texts))
@given(st.dictionaries(st.integers(), st.none()))
def test_filter_keys(mapping):
    p = lambda k: k % 2 == 0
    new = filter_keys(p, mapping)
    # A new dict is returned, never the input itself.
    assert new is not mapping
    # Every surviving key passes the predicate...
    assert all(p(k) for k in new.keys())
    # ...every failing key is dropped...
    assert all(k not in new for k in mapping if not p(k))
    # ...and every passing key is kept with its original value.
    assert all(k in new and new[k] == v for k, v in mapping.items() if p(k))
@given(st.dictionaries(st.integers(), st.integers(min_value=1)))
def test_filter_dict(mapping):
    # min_value=1 above keeps the modulo in the predicate well-defined.
    p = lambda k, v: k % v == 0
    new = filter_dict(p, mapping)
    # A new dict is returned, never the input itself.
    assert new is not mapping
    # Every surviving item passes the predicate...
    assert all(p(k, v) for k, v in new.items())
    # ...every failing item is dropped...
    assert all(k not in new for k, v in mapping.items() if not p(k, v))
    # ...and every passing item is kept with its original value.
    assert all(k in new and new[k] == v for k, v in mapping.items() if p(k, v))
@given(st.lists(st.integers(), max_size=2) | st.lists(st.integers(), max_size=2).map(iter), st.none() | st.integers())
def test_head(iterable, default):
    # Draws either a list or a one-shot iterator; tee lets us inspect it without
    # disturbing the copy handed to unsafe_head.
    iter1, iter2 = tee(iterable)
    iter_copy = list(iter2)
    the_head = unsafe_head(iter1, default)
    rest = iter_copy[1:]
    if not iter_copy:
        assert the_head == default
    else:
        assert [the_head] + rest == iter_copy  # head + rest == all
    # iter1 is an iterator: unsafe_head must consume at most its first element.
    assert list(iter1) == rest
class MonoidLawTest(object):
    """Mixin that property-tests the monoid laws for ``monoid`` over values drawn from ``strategy``.

    Meant to be combined with unittest.TestCase; subclasses set both class attributes.
    """

    monoid = None  # the Monoid instance under test
    strategy = None  # hypothesis strategy producing values for that monoid

    def test_identity(self):
        """zero() is both a right and a left identity for op."""
        @given(self.strategy)
        def _check_identity(x):
            monoid = self.monoid
            assert monoid.op(x, monoid.zero()) == x  # right identity
            assert monoid.op(monoid.zero(), x) == x  # left identity

        _check_identity()  # pylint: disable=no-value-for-parameter

    def test_associativity(self):
        """Grouping does not matter: (a op b) op c == a op (b op c)."""
        @given(self.strategy, self.strategy, self.strategy)
        @settings(suppress_health_check=[HealthCheck.too_slow])
        def _check_associativity(a, b, c):
            monoid = self.monoid
            left_grouped = monoid.op(monoid.op(a, b), c)
            right_grouped = monoid.op(a, monoid.op(b, c))
            assert left_grouped == right_grouped

        _check_associativity()  # pylint: disable=no-value-for-parameter

    def test_monoid_sum(self):
        """The monoid's own (possibly overridden) monoid_sum agrees with the free monoid_sum."""
        @given(st.lists(self.strategy))
        @settings(suppress_health_check=[HealthCheck.too_slow])
        def _check_monoid_sum(values):
            reference = monoid_sum(self.monoid, values)
            specialised = self.monoid.monoid_sum(values)
            assert reference == specialised

        _check_monoid_sum()  # pylint: disable=no-value-for-parameter
class TestSetMonoid(MonoidLawTest, unittest.TestCase):
    """Monoid laws for SetMonoid over sets of integers."""
    strategy = st.sets(st.integers())
    monoid = SetMonoid
class TestDictMergeMonoid(MonoidLawTest, unittest.TestCase):
    """Monoid laws for DictMergeMonoid over int-keyed dicts of integer sets."""
    strategy = st.dictionaries(st.integers(), st.sets(st.integers()))
    # SetMonoid merges the values for convenience; any valid monoid would do.
    monoid = DictMergeMonoid(SetMonoid)
class TestListMonoid(MonoidLawTest, unittest.TestCase):
    """Monoid laws for ListMonoid over lists of integers."""
    strategy = st.lists(st.integers())
    monoid = ListMonoid
def test_tupled():
    """tupled(f) called with a tuple matches f called with the unpacked items."""
    pair = (1, 2)
    assert tupled(op.add)(pair) == op.add(*pair)
def test_untupled():
    """untupled(f) called with separate arguments matches f called with them packed in a tuple."""
    def as_pair(t):
        return (t[0], t[1])
    assert untupled(as_pair)(1, 2) == as_pair((1, 2))
def test_tuple_inverts_untupled():
    """tupled(untupled(f)) behaves like f."""
    def as_pair(t):
        return (t[0], t[1])
    round_tripped = tupled(untupled(as_pair))
    assert round_tripped((1, 2)) == as_pair((1, 2))
def test_untupled_inverts_tupled():
    """untupled(tupled(f)) behaves like f."""
    round_tripped = untupled(tupled(op.add))
    assert round_tripped(1, 2) == op.add(1, 2)
def test_kwstarred():
    """kwstarred(f) called with a dict matches f called with that dict's keyword arguments."""
    def f(a, b):
        return a, b
    kwargs = {'a': 1, 'b': 2}
    assert kwstarred(f)(kwargs) == f(**kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment