Last active
April 8, 2024 15:32
-
-
Save innateessence/7395fd574c1d4c1382a8f38adaa6eae4 to your computer and use it in GitHub Desktop.
chainable.py - Implement pipe operator + currying
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
''' | |
Just a fun project one night after work. | |
I enjoy this syntax. | |
''' | |
from typing import Any | |
from collections.abc import Callable | |
class Chainable: | |
""" | |
This class is used to wrap a function and its arguments in a chainable manner. | |
It is also used to overload the right shift operator (>>) to chain the functions together. | |
""" | |
def __init__(self, func, *args): | |
self.func = func | |
self.func_name = func.__name__ | |
self.args = args | |
self.retval = func(*args) | |
def __call__(self, *args): | |
new_args = args + self.args | |
self.retval = self.func(*new_args) | |
cur = Chainable(self.func, *new_args) | |
cur.retval = self.retval | |
return cur | |
def __rshift__(self, other): | |
left = self | |
right = other | |
if not isinstance(right, Callable): | |
raise ValueError("Right-hand side must be a callable") | |
return right(left.retval) | |
def __repr__(self): | |
args_str = ", ".join(map(str, self.args)).replace("\n", r"\n") | |
return f"<{self.__class__.__name__} {self.func_name}({args_str})>" | |
@property | |
def value(self): | |
return self.retval | |
def chainable(func):
    """Decorator: make ``func(*args)`` return ``Chainable(func, *args)``.

    The decorated function therefore supports currying (``f(1)(2)``) and
    piping (``f(1) >> g(2)``). Only positional arguments are accepted by the
    returned wrapper. ``functools.wraps`` keeps the original name, docstring
    and signature metadata visible on the decorated function.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args):
        return Chainable(func, *args)

    return wrapper
@chainable
def add(*args) -> int | float:
    """Return the sum of all arguments (0 when called with none)."""
    total = sum(args)
    return total
@chainable
def sub(*args) -> int | float:
    """Subtract every later argument from the first one.

    Raises:
        ValueError: if called with no arguments.
    """
    if not args:
        raise ValueError("At least one argument is required")
    minuend, *subtrahends = args
    return minuend - sum(subtrahends)
@chainable
def mul(*args) -> int | float:
    """Return the product of all arguments (1 when called with none)."""
    product = 1
    for factor in args:
        product *= factor
    return product
@chainable
def div(*args) -> int | float:
    """Divide the first argument by each later argument in turn (true division).

    Raises:
        ValueError: if called with no arguments.
        ZeroDivisionError: if any divisor is zero.
    """
    if not args:
        raise ValueError("At least one argument is required")
    quotient, *divisors = args
    for divisor in divisors:
        quotient /= divisor
    return quotient
@chainable
def pow(*args) -> int | float:  # NOTE: shadows the builtin ``pow`` in this module
    """Raise the first argument to each later argument, left to right.

    ``pow(2, 3, 2)`` is ``(2 ** 3) ** 2``, so a value piped in with ``>>``
    acts as the base.

    Raises:
        ValueError: if called with no arguments.
    """
    if not args:
        raise ValueError("At least one argument is required")
    result, *exponents = args
    for exponent in exponents:
        result **= exponent
    return result
def toValue(x: Any):
    """Identity function.

    Placed at the end of a ``>>`` pipe, it receives the previous stage's
    result and returns it unchanged as a plain value.
    """
    value = x
    return value
class ShellError(Exception):
    """Raised by ``sh`` when a command writes anything to stderr."""
def _sh(*args):
    """Run a shell command and return its stdout as text.

    Accepts ``(cmd,)`` or ``(stdin, cmd)``: the command is always the LAST
    positional argument. ``Chainable.__call__`` prepends a piped value, so
    with this ordering ``left >> sh(cmd)`` feeds ``left``'s result to ``cmd``
    on stdin rather than executing it as a command.

    Raises:
        TypeError: if not called with exactly one or two arguments.
        ShellError: if the command writes anything to stderr. A failing
            command that writes nothing to stderr is NOT reported.
    """
    import subprocess as sp

    if len(args) not in (1, 2):
        raise TypeError(f"sh expects (cmd) or (stdin, cmd), got {len(args)} arguments")
    *piped, cmd = args
    stdin = piped[0] if piped else None

    # shell=True: cmd is interpreted by the shell; never pass untrusted text here.
    with sp.Popen(
        cmd,
        shell=True,
        stdout=sp.PIPE,
        stderr=sp.PIPE,
        stdin=sp.PIPE,
        encoding="utf-8",
    ) as proc:
        # With input=None the child's stdin pipe is closed immediately (EOF).
        stdout, stderr = proc.communicate(input=stdin)
    if stderr:
        raise ShellError(stderr)
    return stdout


def sh(cmd: str, stdin=None):
    """Build a chainable shell command.

    ``sh(cmd)`` runs ``cmd`` with an empty stdin; ``sh(cmd, stdin)`` feeds
    ``stdin`` to it. In a pipe, ``left >> sh(cmd)`` runs ``cmd`` again with
    ``left``'s result on stdin. Because ``Chainable`` evaluates eagerly, the
    right-hand ``sh(cmd)`` also runs once with empty stdin before the pipe
    feeds it, so avoid side-effecting commands on the right of ``>>``.

    Raises:
        ShellError: if the command writes anything to stderr.
    """
    # _sh expects the command last, matching how Chainable prepends piped input.
    args = (cmd,) if stdin is None else (stdin, cmd)
    return Chainable(_sh, *args)
if __name__ == "__main__":
    # Each case is (thunk, expected). Thunks run one at a time, in order, so
    # the first mismatch stops the run exactly like a sequence of asserts.
    _cases = (
        # All operands in a single call.
        (lambda: add(1, 2, 4).value, 7),
        # Currying: every call prepends one more argument.
        (lambda: add(1)(2)(4).value, 7),
        # Piping: each stage receives the previous result as its first argument.
        (lambda: add(1) >> add(2) >> add(4) >> toValue, 7),
        (lambda: add(1)(2) >> add(2)(2) >> toValue, 7),
        (lambda: add(1)(2) >> add(2)(2) >> sub(2) >> toValue, 5),
        (lambda: add(4) >> add(4) >> sub(5) >> mul(2) >> toValue, 6),
        (lambda: add(4) >> add(4) >> sub(5) >> mul(2) >> div(3) >> toValue, 2),
    )
    for _compute, _expected in _cases:
        assert _compute() == _expected
    print("All tests passed!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment