Last active: January 2, 2021 07:53
-
-
Save elazarg/99e6d1450c3c85aa74317c1b01767d86 to your computer and use it in GitHub Desktop.
Python 3: a decorator, `fresh_defaults`, that makes mutable default arguments safe to use in functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
from copy import deepcopy
from functools import wraps
from types import FunctionType
def fresh_defaults(f):
    '''
    Give every call of *f* its own fresh copy of the default arguments.

    Python evaluates default values once, at definition time, so a mutable
    default such as ``[]`` is shared by every call that omits the argument.
    This decorator deep-copies the defaults before each call, so changes
    made during one call are never seen by the next.

    If every default (positional and keyword-only) is built only from
    immutable values, *f* is returned unchanged, since there is nothing to
    protect. Otherwise a clone of *f* is wrapped, and *f* itself is not
    modified.

    The defaults must be deep-copyable, i.e. not objects such as modules,
    methods, stack traces, stack frames, files, sockets, windows or arrays.
    An uncopyable default raises at decoration time, not at the first call.

    NOTE: the wrapper installs the fresh defaults on the shared clone and
    then calls it, in two separate steps. Concurrent calls from several
    threads may therefore end up sharing one copy.

    >>> @fresh_defaults
    ... def foo(a=[], *, b=set()):
    ...     a.append(1)
    ...     b.add(1)
    ...     print(a, b)
    >>> foo()
    [1] {1}
    >>> foo()
    [1] {1}
    '''
    # __kwdefaults__ is None (not an empty dict) when f has no keyword-only
    # defaults, so normalise it before calling .values().
    kwdefaults = f.__kwdefaults__ or {}
    if _all_immutable(f.__defaults__) and _all_immutable(kwdefaults.values()):
        return f
    f = _clone_function(f)
    # Pristine snapshot, taken once. Copying here also surfaces uncopyable
    # defaults at definition time. Copying both collections in one deepcopy
    # call preserves any aliasing between defaults.
    defaults = deepcopy((f.__defaults__, f.__kwdefaults__))

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Install brand-new copies on the clone right before it binds its
        # arguments; the snapshot itself is never handed out.
        f.__defaults__, f.__kwdefaults__ = deepcopy(defaults)
        return f(*args, **kwargs)

    return wrapper


def _clone_function(f):
    '''
    Return a new function object equivalent to *f* with deep-copied defaults.

    The code object, globals and closure cells are shared with *f*; only
    the positional and keyword-only defaults are duplicated. The clone can
    therefore have its defaults replaced without affecting *f*. Metadata
    (name, qualname, module, doc, annotations, __dict__) is copied from *f*
    by functools.update_wrapper.
    '''
    # The environment (globals / closure) is deliberately shared, not copied:
    # the clone must see exactly the same free variables as the original.
    # Based on http://stackoverflow.com/a/13503277/2289509
    new_f = FunctionType(f.__code__, f.__globals__, name=f.__name__,
                         argdefs=deepcopy(f.__defaults__),
                         closure=f.__closure__)
    new_f = functools.update_wrapper(new_f, f)
    new_f.__kwdefaults__ = deepcopy(f.__kwdefaults__)
    # update_wrapper sets new_f.__wrapped__ = f; drop that back-reference
    # from the clone (f itself normally has no __wrapped__ to delete).
    del new_f.__wrapped__
    return new_f


def _all_immutable(xs):
    '''
    Return True if every item of *xs* is built only from immutable values.

    Leaves are None, bool, int, float, complex, str and bytes. Tuples and
    frozensets qualify only if all of their items do (checked recursively).
    ``None`` for *xs* itself (e.g. ``f.__defaults__`` of a function with no
    defaults) is treated as an empty collection.
    '''
    if xs is None:
        return True
    leaves = (str, bytes, int, bool, float, complex, type(None))
    branches = (tuple, frozenset)
    return all(
        isinstance(x, leaves)
        or (isinstance(x, branches) and _all_immutable(x))
        for x in xs
    )
def test_positional():
    """A positional list default is shared by plain calls but fresh under @fresh_defaults."""
    def bump(a=[1]):
        a.append(a[-1] + 1)
        return a[-1]

    @fresh_defaults
    def bump_fresh(a=[1]):
        a.append(a[-1] + 1)
        return a[-1]

    # Undecorated: the second call sees the list grown by the first call.
    first, second = bump(), bump()
    assert first != second
    # Decorated: each call starts from a fresh [1].
    assert bump_fresh() == bump_fresh()
def test_keywords():
    """Keyword-only list defaults are fresh both via decorator syntax and a direct call."""
    def counter(*, a=[1]):
        a.append(a[-1] + 1)
        return a[-1]

    # Direct application to an existing function object (exercises cloning).
    counter_fresh_copy = fresh_defaults(counter)

    @fresh_defaults
    def counter_fresh(*, a=[1]):
        a.append(a[-1] + 1)
        return a[-1]

    assert counter() != counter()
    assert counter_fresh() == counter_fresh()
    assert counter_fresh_copy() == counter_fresh_copy()
def test_unchanged():
    """fresh_defaults returns the function itself only when all defaults are immutable."""
    def foo(a=((1,), 2), b=None, *, c=frozenset([1.0, 'hello'])):
        pass

    def bar(a=[(1,), 2], b=None, *, c=frozenset([1.0, 'hello'])):
        pass

    def car(a=((1,), 2), b=None, *, c=({}, 'hello')):
        pass

    # Regression cases: no keyword-only defaults, so __kwdefaults__ is None.
    # In the second case __defaults__ is None too.
    def no_kw_defaults(a=1, b='x'):
        pass

    def no_defaults(a, *, b):
        pass

    assert fresh_defaults(foo) is foo
    assert fresh_defaults(bar) is not bar
    assert fresh_defaults(car) is not car
    assert fresh_defaults(no_kw_defaults) is no_kw_defaults
    assert fresh_defaults(no_defaults) is no_defaults
if __name__ == '__main__':
    # Run the self-tests, in order, when executed as a script.
    for run_test in (test_positional, test_keywords, test_unchanged):
        run_test()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment