Skip to content

Instantly share code, notes, and snippets.

@sarkologist
Last active February 8, 2018 03:54
Show Gist options
  • Save sarkologist/a4903853748f3c7948e0df4a48b3af46 to your computer and use it in GitHub Desktop.
Save sarkologist/a4903853748f3c7948e0df4a48b3af46 to your computer and use it in GitHub Desktop.
Composable folds in Python
from functools import partial, wraps, reduce
class Fold:
    """A left fold described by a seed, a step function and a finalizer.

    Attributes:
        zero: Initial accumulator value.
        update: Callable ``(accumulator, item) -> accumulator``.
        out: Callable applied to the final accumulator to produce the result.

    Folds compose so that several of them run in a single pass:

    * ``a + b`` feeds every item to both folds and produces
      ``a.out(acc_a)(b.out(acc_b))``, so the left fold's output must be a
      callable that accepts the right fold's output.
    * ``func * fold`` post-processes the fold's output with ``func``.
    """

    def __init__(self, zero, update, out=None):
        """Create a fold.

        Args:
            zero: Initial accumulator value.
            update: Callable ``(accumulator, item) -> accumulator``.
            out: Callable applied to the final accumulator. Defaults to the
                identity, so the accumulator itself is the result.
        """
        self.zero = zero
        self.update = update
        self.out = (lambda x: x) if out is None else out

    def __add__(self, other):
        """Pair this fold with ``other`` so both consume the same items.

        The combined accumulator is a 2-tuple ``(self_acc, other_acc)``.
        The combined output applies this fold's output (expected to be a
        callable) to the other fold's output.
        """
        # Duck-typed check: anything exposing the three fold attributes is
        # accepted; other operands get Python's standard TypeError instead
        # of an AttributeError from deep inside the method.
        if not all(hasattr(other, attr) for attr in ("zero", "update", "out")):
            return NotImplemented

        def update(old, new):
            return (self.update(old[0], new), other.update(old[1], new))

        def out(acc):
            return self.out(acc[0])(other.out(acc[1]))

        return Fold((self.zero, other.zero), update, out)

    def __rmul__(self, other):
        """Return a fold whose output is ``other(self.out(acc))``.

        Invoked for ``other * fold`` when ``other`` does not itself handle
        multiplication with a ``Fold`` (e.g. a plain function).
        """
        # A non-callable left operand could never be applied to the output.
        if not callable(other):
            return NotImplemented
        return Fold(self.zero, self.update, lambda x: other(self.out(x)))
def run_fold(f, xs):
    """Run fold ``f`` over the iterable ``xs`` and return its final output.

    Starting from ``f.zero``, each item is merged into the accumulator with
    ``f.update``; the final accumulator is then passed through ``f.out``.
    """
    final_state = reduce(f.update, xs, f.zero)
    return f.out(final_state)
def curry(f):
    """Return a curried version of ``f`` that takes one argument per call.

    Each call supplies one more positional argument. As soon as the
    collected arguments satisfy ``f``'s signature, ``f`` is called and its
    result returned; otherwise another single-argument function is returned.

    Whether enough arguments have been supplied is decided by inspecting
    the signature, not by calling ``f`` and catching ``TypeError``. A
    ``TypeError`` raised inside ``f``'s body therefore propagates to the
    caller instead of being mistaken for a missing argument.

    For callables that expose no inspectable signature, the function falls
    back to calling ``f`` and treating ``TypeError`` as "more arguments
    needed".
    """
    # Imported locally so this function does not rely on the file's
    # top-level import list.
    from inspect import signature

    try:
        sig = signature(f)
    except (TypeError, ValueError):
        # No signature could be obtained for this callable.
        sig = None

    def satisfied_by(args):
        # True when ``args`` form a complete positional call of ``f``.
        try:
            sig.bind(*args)
        except TypeError:
            return False
        return True

    def collect(args):
        @wraps(f)
        def _(arg):
            supplied = args + (arg,)
            if sig is None:
                # Fallback: probe by calling, as there is nothing to inspect.
                try:
                    return f(*supplied)
                except TypeError:
                    return collect(supplied)
            if satisfied_by(supplied):
                return f(*supplied)
            return collect(supplied)
        return _

    return collect(())
from comparison.fold import Fold, run_fold, curry
def pair(x, y):
    """Bundle the two arguments into a 2-tuple."""
    return x, y
def triple(x, y, z):
    """Bundle the three arguments into a 3-tuple."""
    return x, y, z
def test_curry():
    """Curried functions take their arguments one call at a time."""
    curried_pair = curry(pair)
    assert curried_pair(1)(2) == (1, 2)

    curried_triple = curry(triple)
    assert curried_triple(1)(2)(3) == (1, 2, 3)
def test_fold():
    """Folds combined with ``+`` run together and feed a curried combiner."""

    def keep(acc):
        # Finalizer that returns the accumulator unchanged.
        return acc

    summ = Fold(0, lambda acc, item: acc + item, keep)
    product = Fold(1, lambda acc, item: acc * item, keep)
    subtract = Fold(0, lambda acc, item: acc - item, keep)
    numbers = [1, 2, 3, 4]

    sum_and_product = curry(pair) * summ + product
    assert run_fold(sum_and_product, numbers) == (10, 24)

    all_three = curry(triple) * summ + product + subtract
    assert run_fold(all_three, numbers) == (10, 24, -10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment