-
-
Save metavee/42b394601dc0014da53d195b7e06cd1b to your computer and use it in GitHub Desktop.
Decorator to make it safe to use mutable default args.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from functools import wraps | |
import inspect | |
from typing import Callable | |
def safe_defaults(f: Callable) -> Callable:
    """
    Decorate a function to allow safe usage of mutable default parameters.

    A parameter may have a list, dict or set default, including subclasses
    such as ``OrderedDict`` or ``defaultdict``. Each time the function is
    called without an explicit value for such a parameter, it receives a
    fresh deep copy of the default. Without this decorator, Python would
    reuse one shared object across calls.

    E.g., instead of this::

        def my_function(arg: Optional[List] = None):
            if arg is None:
                arg = []
            ...

    You can write::

        @safe_defaults
        def my_function(arg: List = []):
            ...

    >>> @safe_defaults
    ... def append_one(items=[]):
    ...     items.append(1)
    ...     return items
    >>> append_one()
    [1]
    >>> append_one()
    [1]

    Adapted from this Stack Overflow answer by @303931/lucas-wiman
    https://stackoverflow.com/a/69170441

    Parameters
    ----------
    f: Callable
        Function to decorate.

    Returns
    -------
    Callable
        Wrapped function. It raises ``TypeError`` (from ``Signature.bind``)
        when called with arguments that do not match ``f``'s signature.
    """
    sig = inspect.signature(f)

    # Pristine copies of the mutable defaults, taken once at decoration time.
    # They are never handed to ``f`` directly, so they can never be mutated.
    # isinstance (rather than an exact type match) also covers subclasses.
    default_values = {
        name: deepcopy(param.default)
        for name, param in sig.parameters.items()
        if isinstance(param.default, (list, set, dict))
    }

    @wraps(f)
    def wrapper(*args, **kwds):
        # Binding rejects arguments that do not fit the signature and
        # records exactly which parameters the caller supplied.
        bound = sig.bind(*args, **kwds)
        for param_name, template in default_values.items():
            if param_name not in bound.arguments:
                bound.arguments[param_name] = deepcopy(template)
        # Fill every other omitted parameter with its declared default so that
        # bound.args has no gaps. Without this, a fresh copy destined for a
        # positional-only parameter could land in bound.kwargs and make the
        # call fail.
        bound.apply_defaults()
        # bound.args / bound.kwargs respect each parameter's kind, so
        # positional-only parameters are passed positionally and keyword-only
        # ones by name.
        return f(*bound.args, **bound.kwargs)

    return wrapper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment