Skip to content

Instantly share code, notes, and snippets.

@gvx
Created October 23, 2020 14:11
Show Gist options
  • Save gvx/e57f0f6babdd56c6d1ef5f1787ab4c7e to your computer and use it in GitHub Desktop.
Save gvx/e57f0f6babdd56c6d1ef5f1787ab4c7e to your computer and use it in GitHub Desktop.
Create closures in Python without default arguments
import dis
from typing import Iterable, TypeVar, Any
from collections.abc import Callable
from types import CodeType
# Type of the function being decorated: closure() returns a decorator typed
# Callable[[F], F], so the decorated function keeps its own static type.
F = TypeVar('F', bound=Callable)
def extract_mapping(names: tuple[str, ...], mapping: dict[str, int]) -> dict[int, int]:
    """Re-key ``mapping`` by each name's position in ``names``.

    Args:
        names: Identifiers to look names up in (e.g. a code object's
            ``co_names``).
        mapping: Name -> value pairs to translate.

    Returns:
        A dict from the index of the first occurrence of each name in
        ``names`` to that name's value in ``mapping``. Names that do not
        occur in ``names`` are skipped; iteration order follows ``mapping``.
    """
    by_position: dict[int, int] = {}
    for name, value in mapping.items():
        if name not in names:
            continue
        by_position[names.index(name)] = value
    return by_position
# Raw opcode numbers used when re-emitting bytecode.
LOAD_CONST = dis.opmap['LOAD_CONST']
EXTENDED_ARG = dis.opmap['EXTENDED_ARG']


def _encode(opcode: int, arg: int, prefixes: int) -> Iterable[int]:
    """Yield the bytes of one instruction using exactly ``prefixes`` EXTENDED_ARGs.

    Re-using the original instruction's prefix count keeps every instruction
    at its original offset, so jump targets in the code stay valid.

    Raises:
        ValueError: if ``arg`` is negative or does not fit in
            ``prefixes + 1`` bytes.
    """
    if arg < 0 or arg >> (8 * (prefixes + 1)):
        raise ValueError(
            f'argument {arg} of {dis.opname[opcode]} does not fit in '
            f'{prefixes + 1} byte(s); the bytecode cannot be rewritten in place')
    # High-order bytes go into the EXTENDED_ARG prefixes, most significant first.
    for shift in range(8 * prefixes, 0, -8):
        yield EXTENDED_ARG
        yield (arg >> shift) & 0xFF
    yield opcode
    yield arg & 0xFF


def new_opcodes(code: CodeType, global_overrides: dict[int, int], enclosing_overrides: dict[int, int]) -> Iterable[int]:
    """Yield ``code``'s bytecode with selected loads replaced by LOAD_CONST.

    Args:
        code: Code object whose instructions are re-emitted.
        global_overrides: Maps a LOAD_NAME/LOAD_GLOBAL argument (an index into
            ``code.co_names``) to the constant index to load instead.
        enclosing_overrides: Maps a LOAD_DEREF argument to the constant index
            to load instead.

    Returns:
        An iterator of ints in ``range(256)``, suitable for ``bytes(...)``.
        Every instruction keeps its original size and offset.

    Raises:
        ValueError: if a replacement constant index does not fit in the space
            the replaced instruction's argument occupied, or if the bytecode
            ends with a dangling EXTENDED_ARG.

    NOTE(review): this assumes the CPython 3.8-3.10 bytecode layout. From 3.11
    on, ``dis.get_instructions`` omits inline CACHE entries by default and the
    LOAD_GLOBAL/LOAD_DEREF arguments are encoded differently, so the output is
    likely not valid there -- confirm the supported Python versions.
    """
    prefixes = 0  # EXTENDED_ARG instructions seen since the last real instruction
    for instruction in dis.get_instructions(code):
        if instruction.opcode == EXTENDED_ARG:
            # get_instructions already folds these into the next instruction's
            # ``arg``; they are re-emitted together with that instruction.
            prefixes += 1
            continue
        if (instruction.opname in ('LOAD_NAME', 'LOAD_GLOBAL')
                and instruction.arg in global_overrides):
            opcode, arg = LOAD_CONST, global_overrides[instruction.arg]
        elif (instruction.opname == 'LOAD_DEREF'
                and instruction.arg in enclosing_overrides):
            opcode, arg = LOAD_CONST, enclosing_overrides[instruction.arg]
        else:
            opcode, arg = instruction.opcode, instruction.arg or 0
        yield from _encode(opcode, arg, prefixes)
        prefixes = 0
    if prefixes:
        raise ValueError('bytecode ends with a dangling EXTENDED_ARG')
def closure(**kwargs: Any) -> Callable[[F], F]:
    """Return a decorator that bakes ``kwargs`` into a function as constants.

    Each keyword's value is appended to the function's ``co_consts``. Every
    LOAD_NAME/LOAD_GLOBAL of a matching name in ``co_names``, and every
    LOAD_DEREF of a matching cell/free variable, is rewritten to LOAD_CONST
    of that value. The values are therefore fixed at decoration time and,
    unlike the ``def f(x=x)`` idiom, do not appear as parameters. The
    decorated function's ``__code__`` is replaced in place and the same
    function object is returned.

    When ``__debug__`` is true, an ``AssertionError`` is raised if some
    keyword name appears neither in ``co_names`` nor among the cell/free
    variables.

    NOTE(review): the LOAD_DEREF indexing (cellvars + freevars) matches the
    pre-3.11 bytecode format -- confirm the supported Python versions.
    """
    def _closure(f: F) -> F:
        original = f.__code__
        first_new = len(original.co_consts)
        # keyword name -> index its value will occupy in the extended co_consts
        constant_indexes = dict(zip(kwargs, range(first_new, first_new + len(kwargs))))
        global_overrides = extract_mapping(original.co_names, constant_indexes)
        deref_names = original.co_cellvars + original.co_freevars
        enclosing_overrides = extract_mapping(deref_names, constant_indexes)
        if __debug__:
            used = set(global_overrides.values()) | set(enclosing_overrides.values())
            unused = {name for name, index in constant_indexes.items() if index not in used}
            assert not unused, f'some variables ({", ".join(unused)}) are defined but not used in the function {f.__qualname__}'
        extended_consts = original.co_consts + tuple(kwargs.values())
        rewritten = bytes(new_opcodes(original, global_overrides, enclosing_overrides))
        f.__code__ = original.replace(co_consts=extended_consts, co_code=rewritten)
        return f
    return _closure
from closure_decorator import closure

# Build ten functions; each captures the value ``i`` had when it was decorated.
callbacks = []
for i in range(10):
    @closure(i=i)
    def foo() -> None:
        print(i)
    callbacks.append(foo)

for callback in callbacks:
    callback()
# prints 0 ... 9 instead of all nines!
# foo takes no parameters, so callback(i=9) raises TypeError: unlike the
# ``def foo(i=i)`` default-argument trick, the captured value cannot be overridden.

# The baked-in name does not collide with names passed through **kwargs:
@closure(collector_type=dict)
def collect_some_args(**kwargs):
    return collector_type(kwargs)

print(collect_some_args(collector_type=list))
# prints {'collector_type': <class 'list'>}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment