Skip to content

Instantly share code, notes, and snippets.

@urigoren
Last active April 19, 2021 20:50
Show Gist options
  • Save urigoren/206851969cae30becabbd1d3c2ab526b to your computer and use it in GitHub Desktop.
Save urigoren/206851969cae30becabbd1d3c2ab526b to your computer and use it in GitHub Desktop.
Use arrow notation (>>) like Haskell to make filter, map and reduce operations more readable.
from itertools import chain
from functools import reduce
import operator
"""
Usage of this module:
<iterable> >> function3 * function2 * function1 >> aggregation
for example:
[1,2,3] >> add == 6
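list([1,2,3] >> unifunc(lambda x: x+1)) == [2,3,4]
list([1,2,3] >> unifunc(lambda x: x+1) * (lambda x: x-1)) == [1,2,3]
(on Python 3 the result of >> with a unifunc is a lazy iterator, hence the list(...) wrapper)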
See tests for more examples
"""
class unifunc(object):
    """Wrap a one-argument callable so it composes with ``*`` and maps with ``>>``.

    * ``f * g`` builds a new ``unifunc`` computing ``f(g(x))``; either side
      may be a plain callable.
    * ``iterable >> f`` applies ``f`` to every item and drops results that
      are ``None`` (so a single function can both map and filter).
    * ``iterable << f`` is a flat map: ``f`` must return an iterable for each
      item; the results are chained and ``None`` items are dropped.
    """

    def __init__(self, func):
        """Store ``func``.

        Raises:
            TypeError: if ``func`` is not callable.
        """
        if not callable(func):
            # The 1-tuple keeps ``%`` from unpacking ``func`` when it is
            # itself a tuple, which would garble or break the message.
            raise TypeError(
                "unifunc parameter should be callable; %r is not." % (func,))
        self.func = func

    @staticmethod
    def _unwrap(other):
        """Return the plain callable to use for ``other`` in a composition.

        Raises:
            TypeError: if ``other`` is not callable.
        """
        if isinstance(other, unifunc):
            return other.func
        if callable(other):
            # Use the callable as-is. Do not look for a ``.func`` attribute:
            # functools.partial exposes one that lacks the bound arguments.
            return other
        raise TypeError(
            "unifunc parameter should be callable; %r is not." % (other,))

    def __call__(self, x):
        """Apply the wrapped callable to ``x``."""
        return self.func(x)

    def __mul__(self, other):
        """Return ``self * other``: a unifunc computing ``self(other(x))``."""
        outer = self.func
        inner = self._unwrap(other)
        return unifunc(lambda x: outer(inner(x)))

    def __rmul__(self, other):
        """Return ``other * self``: a unifunc computing ``other(self(x))``."""
        inner = self.func
        outer = self._unwrap(other)
        return unifunc(lambda x: outer(inner(x)))

    def __rrshift__(self, other):
        """Handle ``other >> self``: map over ``other`` and drop ``None`` results."""
        return filter(lambda x: x is not None, map(self.func, other))

    def __rlshift__(self, other):
        """Handle ``other << self``: essentially, a flat map.

        The wrapped callable must return an iterable for each item.
        """
        # from_iterable consumes ``other`` on demand instead of unpacking
        # every mapped result into an argument list up front.
        flattened = chain.from_iterable(map(self.func, other))
        return filter(lambda x: x is not None, flattened)
class bifunc(object):
    """Wrap a two-argument callable so ``iterable >> f`` folds the iterable with it."""

    def __init__(self, func):
        """Store ``func``.

        Raises:
            TypeError: if ``func`` is not callable.
        """
        if not callable(func):
            # The 1-tuple keeps ``%`` from unpacking ``func`` when it is
            # itself a tuple, which would garble or break the message.
            raise TypeError(
                "bifunc parameter should be callable; %r is not." % (func,))
        self.func = func

    def __call__(self, *args):
        """Call the wrapped callable directly with ``args``."""
        return self.func(*args)

    def __rrshift__(self, other):
        """Handle ``other >> self``: reduce ``other`` left to right.

        No initial value is passed to ``reduce``, so a one-item iterable
        yields that item and an empty iterable raises ``TypeError``.
        """
        return reduce(self.func, other)
class aggfunc(object):
    """Wrap a callable that consumes a whole iterable, so ``iterable >> f`` calls it."""

    def __init__(self, func):
        """Store ``func``.

        Raises:
            TypeError: if ``func`` is not callable.
        """
        if not callable(func):
            # The 1-tuple keeps ``%`` from unpacking ``func`` when it is
            # itself a tuple, which would garble or break the message.
            raise TypeError(
                "aggfunc parameter should be callable; %r is not." % (func,))
        self.func = func

    def __call__(self, *args):
        """Call the wrapped callable directly with ``args``."""
        return self.func(*args)

    def __rrshift__(self, other):
        """Handle ``other >> self``: pass ``other`` as the single argument."""
        return self.func(other)
# --- Ready-made element-wise transforms (right-hand side of >> / <<, or *) ---
# NOTE: ``id`` shadows the builtin of the same name inside this module; the
# name is kept because renaming it would change the module's public names.
id = unifunc(lambda x: x)
# ``at(key)`` builds a getter for ``item[key]``. The getter is wrapped in a
# unifunc so it can be used directly, e.g. ``rows >> at(0)``; a bare
# operator.itemgetter cannot appear on the right of ``>>``.
at = unifunc(lambda key: unifunc(operator.itemgetter(key)))
toString = unifunc(str)
toInt = unifunc(int)
toFloat = unifunc(float)

# --- Pairwise folds: ``iterable >> add`` reduces the iterable ---
add = bifunc(operator.add)
mul = bifunc(operator.mul)
and_ = bifunc(operator.and_)
or_ = bifunc(operator.or_)

# --- Whole-iterable aggregations ---
toList = aggfunc(list)
toSet = aggfunc(frozenset)
toDict = aggfunc(dict)
# Count by iterating, so lazy inputs are not copied into a list first.
count = aggfunc(lambda s: sum(1 for _ in s))
countDistinct = aggfunc(lambda s: len(frozenset(s)))
if __name__ == '__main__':
    # Self-test, executed only when the module is run as a script.

    @unifunc
    def plus1(v):
        return v + 1

    @unifunc
    def times2(v):
        return v * 2

    @unifunc
    def minus_plus(v):
        # Returns a pair so it can be flat-mapped with ``<<``.
        return [v, -v]

    # Direct call goes through unifunc.__call__.
    assert plus1(2) == 3

    # Composition with a plain lambda, then with another unifunc.
    round_trip = [1, 2, 3] >> plus1 * (lambda x: x - 1)
    assert list(round_trip) == [1, 2, 3]
    doubled_then_incremented = [1, 2, 3] >> plus1 * times2
    assert list(doubled_then_incremented) == [3, 5, 7]

    # Flat map.
    mirrored = [1, 2, 3] << minus_plus
    assert list(mirrored) == [1, -1, 2, -2, 3, -3]

    # Fold and aggregation.
    total = [1, 2, 3, 4] >> bifunc(operator.add)
    assert total == 10
    distinct = [1, 2, 1, 2, 1] >> countDistinct
    assert distinct == 2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment