Skip to content

Instantly share code, notes, and snippets.

@jackdreilly
Created October 30, 2023 19:45
Show Gist options
  • Save jackdreilly/f44827a2ffdc479d0c53eb3ce563be6b to your computer and use it in GitHub Desktop.
Save jackdreilly/f44827a2ffdc479d0c53eb3ce563be6b to your computer and use it in GitHub Desktop.
from __future__ import annotations
from dataclasses import dataclass
import operator
from typing import Any, Callable
@dataclass(frozen=True)
class _Args:
args: tuple[Any, ...]
kwargs: dict[str, Any]
def resolve(self, jam: _Jambda):
if not isinstance(jam, _Jambda):
return jam
return (
jam._label(self)
if callable(jam._label)
else self.args[jam._label]
if isinstance(jam._label, int)
else jam._label.resolve(self)
)
@dataclass(frozen=True)
class _Kwarg:
key: str
default_given: bool
default: Any
def resolve(self, args: _Args):
if self.key in args.kwargs:
return args.kwargs[self.key]
if self.default_given:
return self.default
raise KeyError(f"No kwarg provided for key: {self.key}")
@classmethod
def make(cls, arg: str | None = None, **kwargs):
if len(kwargs) > 1:
raise ValueError("Can only provide 1 kwarg")
if not arg and not kwargs:
raise ValueError("Must provide either a string or kwarg w/ default value")
if arg and kwargs:
raise ValueError("Cannot provide both string and kwarg")
if arg:
return _Jambda(cls(arg, False, None))
k, v = list(kwargs.items())[0]
return _Jambda(cls(k, True, v))
@dataclass(frozen=True)
class _Jambda:
_label: Callable[[_Args], Any] | _Kwarg | int = 0
def fn(self, *a, **kw):
return _Args(a, kw).resolve(self)
def __add__(self, o):
return self._wrap(operator.add, o)
def __mul__(self, o):
return self._wrap(operator.mul, o)
def __sub__(self, o):
return self._wrap(operator.sub, o)
def __call__(self, *args: Any, **kwds: Any) -> Any:
def fn(me, *a, **b):
return me(*a, **b)
return self._wrap(fn, *args, **kwds)
def __getattr__(self, k: str):
return self._wrap(operator.attrgetter(k))
def wrap(self, fn: Callable, *a, **kw):
def c(me, *a, **b):
return fn(me, *a, **b)
return self._wrap(c, *a, **kw)
def _wrap(self, op, *a, **kw):
return _Jambda(
lambda args: op(
*map(args.resolve, (self, *a)),
**{k: args.resolve(v) for k, v in kw.items()},
)
)
# Public entry points.
# ``jambda`` and ``arg_0`` are the same placeholder object: the first positional argument.
jambda = arg_0 = _Jambda()
# Placeholder for the second positional argument.
arg_1 = _Jambda(1)
# ``kwarg("name")`` / ``kwarg(name=default)`` build keyword-argument placeholders.
kwarg = _Kwarg.make
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment