Skip to content

Instantly share code, notes, and snippets.

@elazarg
Last active January 2, 2021 07:53
Show Gist options
  • Save elazarg/99e6d1450c3c85aa74317c1b01767d86 to your computer and use it in GitHub Desktop.
Save elazarg/99e6d1450c3c85aa74317c1b01767d86 to your computer and use it in GitHub Desktop.
Python 3: a decorator, `fresh_defaults`, that makes mutable default arguments safe by giving every call a fresh deep copy of the defaults
from functools import wraps
from copy import deepcopy
from types import FunctionType
def fresh_defaults(f):
    '''
    A decorator that allows mutable default arguments (such as lists)
    to be used without worrying about changes across calls.

    This is done simply by deep-copying the defaults before each call.
    The defaults must be deep-copyable, i.e. not of types such as 'module,
    method, stack trace, stack frame, file, socket, window, array, or any
    similar types'; since the defaults are copied once while decorating,
    such a problem surfaces at decoration time rather than on first call.

    If every default (positional and keyword-only) is recursively
    immutable (str, int, bool, float, None, or tuples/frozensets of those),
    ``f`` itself is returned unchanged.

    >>> @fresh_defaults
    ... def foo(a=[], *, b=set()):
    ...     a.append(1)
    ...     b.add(1)
    ...     print(a, b)
    >>> foo()
    [1] {1}
    >>> foo()
    [1] {1}
    '''
    # __kwdefaults__ is None (not an empty dict) when f has no keyword-only
    # parameters with defaults, so guard before calling .values().
    kwdefaults = f.__kwdefaults__ or {}
    if _all_immutable(f.__defaults__) and _all_immutable(kwdefaults.values()):
        return f
    # Work on a private clone so the caller's original function object and
    # its default values are never touched by the wrapper below.
    f = _clone_function(f)
    # This copy is done as an assertion, to catch uncopyable defaults at
    # definition time; it is also the pristine snapshot reused on every call.
    defaults = deepcopy((f.__defaults__, f.__kwdefaults__))

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Install a fresh deep copy of the pristine defaults, then delegate.
        # NOTE(review): this mutates attributes of the shared clone, so
        # concurrent calls from several threads may still end up sharing
        # the same default objects.
        f.__defaults__, f.__kwdefaults__ = deepcopy(defaults)
        return f(*args, **kwargs)
    return wrapper
def _clone_function(f):
    '''
    Return a new function object that shares f's code, globals and closure
    but owns independent deep copies of f's positional and keyword-only
    default values.

    Metadata (__name__, __qualname__, __doc__, __module__, __annotations__
    and the attribute __dict__) is copied from f. The __wrapped__
    back-reference that this copying installs is removed again, so the
    clone does not point at f.
    '''
    # We do not copy the environment (globals / closure cells); it does not
    # make sense to do so, since the clone should behave exactly like f.
    # Based on http://stackoverflow.com/a/13503277/2289509
    new_f = FunctionType(f.__code__, f.__globals__, name=f.__name__,
                         argdefs=deepcopy(f.__defaults__),
                         closure=f.__closure__)
    # wraps(f)(new_f) is equivalent to functools.update_wrapper(new_f, f).
    new_f = wraps(f)(new_f)
    new_f.__kwdefaults__ = deepcopy(f.__kwdefaults__)
    # The wrapper-copying step set new_f.__wrapped__ = f; drop it from the
    # clone (f itself normally has no __wrapped__ attribute to delete).
    del new_f.__wrapped__
    return new_f
def _all_immutable(xs):
    '''
    Tell whether every element of the iterable ``xs`` is recursively
    immutable.

    An element counts as immutable when it is a str/int/bool/float/None
    scalar, or a tuple/frozenset whose own elements all pass this check.
    ``xs is None`` (a function with no defaults of that kind) is trivially
    immutable.
    '''
    if xs is None:
        return True
    scalar_types = (str, int, bool, float, type(None))
    container_types = (tuple, frozenset)
    for item in xs:
        if isinstance(item, scalar_types):
            continue
        if isinstance(item, container_types) and _all_immutable(item):
            continue
        # Anything else (list, dict, set, arbitrary objects) may be mutated.
        return False
    return True
def test_positional():
    '''A mutable positional default accumulates state across calls, unless
    the function is decorated with @fresh_defaults.'''
    def plain(a=[1]):
        a.append(a[-1]+1)
        return a[-1]

    @fresh_defaults
    def fresh(a=[1]):
        a.append(a[-1]+1)
        return a[-1]

    first, second = plain(), plain()
    assert first != second
    assert fresh() == fresh()
def test_keywords():
    '''Keyword-only mutable defaults are refreshed too, including when the
    decorator is applied to an already-defined function.'''
    def plain(*, a=[1]):
        a.append(a[-1]+1)
        return a[-1]

    # Decorated before `plain` is ever called, so its snapshot of the
    # defaults is [1]; later calls to `plain` must not affect it.
    wrapped_later = fresh_defaults(plain)

    @fresh_defaults
    def fresh(*, a=[1]):
        a.append(a[-1]+1)
        return a[-1]

    assert plain() != plain()
    assert fresh() == fresh()
    # tests copying the function
    assert wrapped_later() == wrapped_later()
def test_unchanged():
    '''Functions whose defaults are all immutable are returned as-is; any
    mutable default (even one nested inside a tuple) forces wrapping.'''
    def frozen_only(a=((1,),2), b=None, *, c=frozenset([1.0, 'hello'])):
        pass

    def list_default(a=[(1,),2], b=None, *, c=frozenset([1.0, 'hello'])):
        pass

    def nested_dict(a=((1,),2), b=None, *, c=({}, 'hello')):
        pass

    assert fresh_defaults(frozen_only) is frozen_only
    for candidate in (list_default, nested_dict):
        assert fresh_defaults(candidate) is not candidate
if __name__ == '__main__':
    # Run the self-tests in their original order.
    for run_test in (test_positional, test_keywords, test_unchanged):
        run_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment