Skip to content

Instantly share code, notes, and snippets.

@a-recknagel
Last active October 17, 2023 20:04
Show Gist options
  • Save a-recknagel/c9744b34d0e399da4e4cbc93d56c082a to your computer and use it in GitHub Desktop.
Save a-recknagel/c9744b34d0e399da4e4cbc93d56c082a to your computer and use it in GitHub Desktop.
bootleg inheritance
from inspect import signature
from contextvars import ContextVar
from typing import Callable
from unittest.mock import sentinel
# Marker returned by ``ExtendingCallable.extended_kwargs`` while no call of that
# wrapper is active in the current context.
SENTINEL = sentinel.UNSET

# Maps each ExtendingCallable with an active call to the keyword arguments that
# were routed to its parent function. The mapping is treated as immutable: every
# call installs a *copy* with ``ctx.set`` and undoes it with ``ctx.reset``, so a
# context (asyncio task, ``copy_context().run``, thread) only ever sees its own
# calls. Readers use ``ctx.get({})`` because a thread may start with the variable
# unset.
ctx: ContextVar[dict] = ContextVar('extends_ctx')
ctx.set({})


class ExtendingCallable:
    """Callable wrapper letting a child function also accept a parent's kwargs.

    Each call routes keyword arguments by name. Those naming a parameter of the
    child function ``f`` are passed to it. Those naming a parameter of the
    parent function ``f_`` are exposed to the child, for the duration of the
    call, as :attr:`extended_kwargs`. A name present in both signatures goes to
    both. Positional arguments go to the child only, and keyword arguments
    matching neither signature are dropped.
    """

    def __init__(self, f: Callable, f_: Callable):
        self._child_func = f
        self._parent_func = f_
        self._child_parameters = signature(self._child_func).parameters
        self._parent_parameters = signature(self._parent_func).parameters

        # pseudo-functools.wraps
        self.__module__ = self._child_func.__module__
        self.__name__ = self._child_func.__name__
        self.__qualname__ = self._child_func.__qualname__
        self.__doc__ = self._child_func.__doc__
        # Store the merged annotations on the instance. Assigning them to
        # ``self.__call__.__func__`` would modify the single function object
        # shared by every ExtendingCallable, so each decoration would clobber
        # the annotations of all earlier ones. Child annotations win on clashes.
        self.__annotations__ = {
            **getattr(self._parent_func, "__annotations__", {}),
            **getattr(self._child_func, "__annotations__", {}),
        }

    @property
    def extended_kwargs(self):
        """Parent-bound keyword arguments of the innermost active call.

        Evaluates to ``SENTINEL`` when no call of this wrapper is active in the
        current context (and nothing was assigned explicitly).
        """
        return ctx.get({}).get(self, SENTINEL)

    @extended_kwargs.setter
    def extended_kwargs(self, value):
        # Copy-on-write, so other contexts holding the old mapping are unaffected.
        ctx.set({**ctx.get({}), self: value})

    def __call__(self, *args, **kwargs):
        """Call the child with its own kwargs, exposing the parent's kwargs."""
        parent_kwargs = {
            k: v for k, v in kwargs.items() if k in self._parent_parameters
        }
        child_kwargs = {
            k: v for k, v in kwargs.items() if k in self._child_parameters
        }
        # Key by the wrapper itself (not by ``__name__``) so that two wrapped
        # functions with the same name never share state.
        token = ctx.set({**ctx.get({}), self: parent_kwargs})
        try:
            return self._child_func(*args, **child_kwargs)
        finally:
            # Restore the caller's state even if the child raised; restoring is
            # also what lets the child call itself recursively.
            # NOTE: for a coroutine function ``f`` this runs as soon as the
            # coroutine object is created, i.e. before its body executes.
            ctx.reset(token)
def extends(f_: Callable) -> Callable[[Callable], ExtendingCallable]:
    """Extend the signature of a function ``f`` with that of an existing one ``f_``.

    Use as ``@extends(f_)`` on ``f``. Every keyword argument passed to the
    decorated ``f`` whose name matches a parameter of ``f_`` is stored, per
    call, as a dictionary in the attribute ``f.extended_kwargs``. Keyword
    arguments matching ``f``'s own parameters are passed to ``f`` as usual.
    As a consequence, positional-only parameters of ``f_`` (declared with
    ``/``) cannot be forwarded via ``f_(**f.extended_kwargs)``.

    :param f_: the function whose keyword parameters ``f`` should also accept.
    :returns: a decorator turning ``f`` into an :class:`ExtendingCallable`.
    """

    def outer(f: Callable) -> ExtendingCallable:
        return ExtendingCallable(f, f_)

    return outer
def foo(a: int, b: int, c: float) -> None:
    """Demo parent function: print each argument labelled with its name."""
    line = f"foo: {a=}, {b=}, {c=}"
    print(line)
@extends(foo)
def bar(c: float, d: str, e: str) -> bool:
    """Demo function that extends `foo`: it also accepts foo's ``a`` and ``b``.

    When ``c == "c"`` it first calls itself once with ``c="c_"``, re-using the
    keyword arguments meant for ``foo``. Every call then invokes ``foo`` with
    those arguments, prints its own arguments and returns ``True``.
    """
    if c == "c":
        # Re-enter once with the same foo-bound kwargs, overriding c (which is
        # a parameter of both functions) and passing our own d and e along.
        overrides = {"c": "c_", "d": d, "e": e}
        bar(**{**bar.extended_kwargs, **overrides})
    foo(**bar.extended_kwargs)
    print(f"bar: {c=}, {d=}, {e=}")
    return True
if __name__ == "__main__":
    # Demo. The outer call has c == "c", so bar recurses once with c="c_";
    # the inner call's output is therefore printed first. Expected output:
    # foo: a='a', b='b', c='c_'
    # bar: c='c_', d='d', e='e'
    # foo: a='a', b='b', c='c'
    # bar: c='c', d='d', e='e'
    bar(a="a", b="b", c="c", d="d", e="e")
@a-recknagel
Copy link
Author

Now supports recursive calls of the extending function. I also added `ContextVar`s to hold the state, but haven't tested whether it actually works when called asynchronously.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment