Skip to content

Instantly share code, notes, and snippets.

@MischaPanch
Last active December 17, 2021 10:51
Show Gist options
  • Save MischaPanch/30b25d82093cdef6577146af75badcff to your computer and use it in GitHub Desktop.
Save MischaPanch/30b25d82093cdef6577146af75badcff to your computer and use it in GitHub Desktop.
Allows overloading operators such as +, *, and @ for functions in Python.
import logging
import operator as o
from typing import Callable, Union, Any, Type
log = logging.getLogger()
class _FunctionWrapper:
    """
    Thin callable wrapper that gives a wrapped callable an explicit, inspectable ``__name__``.

    Calling an instance forwards all positional and keyword arguments to the wrapped callable.
    ``repr()`` and ``str()`` both return the stored name.
    """

    def __init__(self, function: Callable, name: Union[str, None] = None):
        """
        :param function: the callable to wrap
        :param name: the name to expose as ``__name__``. If None, the ``__name__`` of ``function``
            is used; callables without that attribute (e.g. instances of classes defining
            ``__call__``) fall back to ``str(function)``.
        """
        self.function = function
        if name is None:
            # Not every callable carries a __name__; use its string form rather than
            # failing with an AttributeError.
            name = getattr(function, "__name__", None)
            if name is None:
                name = str(function)
        self.__name__ = name

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)

    def __repr__(self):
        return self.__name__

    def __str__(self):
        return self.__name__
class Numerical(_FunctionWrapper):
    """
    Using this as decorator allows standard numerical operators to be used for combining functions with other
    callables or with any non-callable objects to create new callables. The composed object will have a
    meaningful __name__ and representation that are useful for inspection.

    Once a function has been wrapped by this decorator, operators can be used on all other objects. Reflected
    variants of the arithmetic operators are provided, so the wrapped function may also be the right operand
    as long as the left operand's type does not handle the operation itself (e.g. plain functions or
    Python numbers).

    Example composition:

    >>> import numpy as np
    >>> @Numerical
    ... def f(x):
    ...     return 2 * x
    >>> def g(x):
    ...     return x**2
    >>> c = f + g
    >>> c.__name__
    'f + g'
    >>> c = c / g
    >>> c.__name__
    '(f + g) / g'
    >>> c = (c + np.array([1, 2])) * 5
    >>> c.__name__
    '((f + g) / g + [1 2]) * 5'
    >>> c(10)
    array([11., 16.])

    Example with the wrapped function as right operand:

    >>> e = 2 + f
    >>> e.__name__
    '2 + f'
    >>> e(10)
    22

    Example comparison:

    >>> d = f < g
    >>> d.__name__
    'f < g'
    >>> d(10)
    True
    >>> d(1)
    False
    >>> d(np.array([1, 10]))
    array([False,  True])
    """

    # --- binary arithmetic (each with its reflected counterpart) ---

    def __add__(self, other: Union[Callable, Any]):
        return _get_apply_binary_operator(self, other, o.add, "+", add_braces=False)

    def __radd__(self, other):
        return _get_apply_binary_operator(other, self, o.add, "+", add_braces=False)

    def __sub__(self, other):
        return _get_apply_binary_operator(self, other, o.sub, "-")

    def __rsub__(self, other):
        return _get_apply_binary_operator(other, self, o.sub, "-")

    def __rmul__(self, other):
        return _get_apply_binary_operator(other, self, o.mul, "*")

    def __mul__(self, other):
        return _get_apply_binary_operator(self, other, o.mul, "*")

    def __matmul__(self, other):
        return _get_apply_binary_operator(self, other, o.matmul, "@")

    def __rmatmul__(self, other):
        return _get_apply_binary_operator(other, self, o.matmul, "@")

    def __floordiv__(self, other):
        return _get_apply_binary_operator(self, other, o.floordiv, "//")

    def __rfloordiv__(self, other):
        return _get_apply_binary_operator(other, self, o.floordiv, "//")

    def __truediv__(self, other):
        return _get_apply_binary_operator(self, other, o.truediv, "/")

    def __rtruediv__(self, other):
        return _get_apply_binary_operator(other, self, o.truediv, "/")

    def __mod__(self, other):
        return _get_apply_binary_operator(self, other, o.mod, "%")

    def __rmod__(self, other):
        return _get_apply_binary_operator(other, self, o.mod, "%")

    def __divmod__(self, other):
        # divmod() must yield the (quotient, remainder) pair, not just the remainder.
        composed = _get_apply_binary_operator(self, other, divmod, "divmod")
        # The infix "a divmod b" name produced by the helper is not valid notation;
        # replace it with the call form.
        composed.__name__ = f"divmod({_get_name(self)}, {_get_name(other)})"
        return composed

    def __pow__(self, other, modulo=None):
        # NOTE: modulo is accepted for signature compatibility with pow() but is not applied.
        return _get_apply_binary_operator(self, other, o.pow, "**")

    def __rpow__(self, other):
        return _get_apply_binary_operator(other, self, o.pow, "**")

    # --- unary operators ---

    def __abs__(self):
        return _get_apply_unary_operator(self, o.abs, name=f"|{self.__name__}|")

    def __neg__(self):
        return _get_apply_unary_operator(self, o.neg, operator_symbol="-")

    def __pos__(self):
        return _get_apply_unary_operator(self, o.pos, operator_symbol="+")

    # --- comparisons (compose into a callable instead of returning a bool) ---

    def __le__(self, other):
        return _get_apply_binary_operator(self, other, o.le, "<=")

    def __lt__(self, other):
        return _get_apply_binary_operator(self, other, o.lt, "<")

    def __ge__(self, other):
        return _get_apply_binary_operator(self, other, o.ge, ">=")

    def __gt__(self, other):
        return _get_apply_binary_operator(self, other, o.gt, ">")
class Boolean(_FunctionWrapper):
    """
    Using this as decorator allows standard logical operators to be used for combining functions with other
    callables or with any non-callable objects to create new callables. The composed object will have a
    meaningful __name__ and representation that are useful for inspection.

    Once a function has been wrapped by this decorator, operators can be used on all other objects. Reflected
    variants of the binary operators are provided, so the wrapped function may also be the right operand
    as long as the left operand's type does not handle the operation itself (e.g. plain functions).

    Example:

    >>> @Boolean
    ... def smaller2(x):
    ...     return x < 2
    >>> def divisible_by4(x):
    ...     return x % 4 == 0
    >>> and_composite = smaller2 & divisible_by4
    >>> and_composite.__name__
    'smaller2 & divisible_by4'
    >>> and_composite(4)
    False
    >>> and_composite(-4)
    True
    >>> negated = ~smaller2
    >>> negated.__name__
    '~smaller2'
    >>> negated(1)
    False
    >>> negated(3)
    True
    """

    def __and__(self, other):
        return _get_apply_binary_operator(
            self, other, o.and_, "&", wrapper_class=Boolean
        )

    def __rand__(self, other):
        return _get_apply_binary_operator(
            other, self, o.and_, "&", wrapper_class=Boolean
        )

    def __or__(self, other):
        return _get_apply_binary_operator(
            self, other, o.or_, "|", wrapper_class=Boolean
        )

    def __ror__(self, other):
        return _get_apply_binary_operator(
            other, self, o.or_, "|", wrapper_class=Boolean
        )

    def __xor__(self, other):
        return _get_apply_binary_operator(
            self, other, o.xor, "^", wrapper_class=Boolean
        )

    def __rxor__(self, other):
        return _get_apply_binary_operator(
            other, self, o.xor, "^", wrapper_class=Boolean
        )

    def __invert__(self):
        return _get_apply_unary_operator(
            self, self._logical_not, operator_symbol="~", wrapper_class=Boolean
        )

    @staticmethod
    def _logical_not(value):
        """
        Negate a result of the wrapped function.

        For a plain Python ``bool`` the bitwise ``~`` yields -2 / -1, which are both truthy, so logical
        ``not`` is used instead. Any other value is inverted with ``~`` as before.
        """
        if isinstance(value, bool):
            return not value
        return o.inv(value)
def _get_name(f: Union[Callable, Any]) -> str:
    """
    Return the name used for ``f`` when building the name of a composed function.

    :param f: a callable or an arbitrary value
    :return: ``str(f)`` if f is not callable; otherwise ``f.__name__``, falling back to ``str(f)``
        for callables that have no ``__name__`` attribute (e.g. instances of classes defining
        ``__call__``)
    """
    if not isinstance(f, Callable):
        return str(f)
    try:
        return f.__name__
    except AttributeError:
        # Callable objects are not guaranteed to carry a __name__.
        return str(f)
def _to_callable(f: Union[Callable, Any]) -> Callable:
    """
    Return ``f`` unchanged if it is callable; otherwise return a function that accepts any
    arguments, ignores them and always returns ``f``.
    """
    if isinstance(f, Callable):
        return f
    return lambda *args, **kwargs: f
def _get_apply_binary_operator(
    f1: Union[Callable, Any],
    f2: Union[Callable, Any],
    operator: Callable[[Any, Any], Any],
    operator_symbol: str,
    add_braces=True,
    wrapper_class: Type[_FunctionWrapper] = Numerical,
):
    """
    Build a wrapper_class instance that evaluates x -> operator(f1(x), f2(x)).

    An operand that is not callable is used as a constant, i.e. its value takes the place of f(x).
    The result is named "<name1> <operator_symbol> <name2>".

    :param f1: left operand, a callable or a plain value
    :param f2: right operand, a callable or a plain value
    :param operator: function combining the two results
    :param operator_symbol: symbol placed between the operand names in the composed name
    :param add_braces: whether to put braces around operand names that contain a space
    :param wrapper_class: class used to wrap the composed function
    :return: the composed function wrapped in wrapper_class
    """
    # Names are collected before the operands are turned into callables, so that
    # constants show up with their own string representation.
    operand_names = []
    for operand in (f1, f2):
        label = _get_name(operand)
        if add_braces and " " in label:
            label = f"({label})"
        operand_names.append(label)
    left_name, right_name = operand_names

    left = _to_callable(f1)
    right = _to_callable(f2)

    def composed_function(*args, **kwargs):
        left_result = left(*args, **kwargs)
        right_result = right(*args, **kwargs)
        return operator(left_result, right_result)

    return wrapper_class(
        composed_function, name=f"{left_name} {operator_symbol} {right_name}"
    )
def _get_apply_unary_operator(
    f: Callable,
    operator: Callable[[Any], Any],
    name: str = None,
    operator_symbol: str = None,
    add_braces=True,
    wrapper_class: Type[_FunctionWrapper] = Numerical,
):
    """
    Returns an instance of the wrapper_class performing x -> operator(f(x)).

    :param f: the callable whose result the operator is applied to
    :param operator: function applied to the result of f
    :param name: name of the returned wrapper; if given, operator_symbol and add_braces will be ignored
    :param operator_symbol: symbol prepended to the name of f; required when name is not given
    :param add_braces: whether to add braces around the function name in the result's name
    :param wrapper_class: class used to wrap the resulting function
    :return: the resulting function wrapped in wrapper_class
    :raises ValueError: if neither name nor operator_symbol is provided
    """

    def operator_applied(*args, **kwargs):
        return operator(f(*args, **kwargs))

    if name is not None:
        if add_braces:
            log.debug(
                "Ignoring add_braces=True b/c explicit name was provided: %s", name
            )
        # Return right away: an explicit name makes operator_symbol optional, so the
        # check below must not be reached on this path.
        return wrapper_class(operator_applied, name=name)

    if operator_symbol is None:
        raise ValueError(
            "operator_symbol cannot be None when name is not provided explicitly"
        )

    name = f.__name__
    if add_braces and " " in name:
        name = f"({name})"
    name = operator_symbol + name
    return wrapper_class(operator_applied, name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment