Skip to content

Instantly share code, notes, and snippets.

@klieret
Last active November 19, 2019 14:23
Show Gist options
  • Save klieret/0f3b1b8a0b9b143702ae1aa662264029 to your computer and use it in GitHub Desktop.
Save klieret/0f3b1b8a0b9b143702ae1aa662264029 to your computer and use it in GitHub Desktop.
Oftentimes, we call functions over a wide range of parameter values, e.g. to generate plots for different scenarios. With several parameters, this leads to ugly nested loops and cumbersome boilerplate. With this snippet, it becomes a single call.
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

from tqdm.auto import tqdm
class _KeepMissingFieldsDict(dict):
    """Mapping for ``str.format_map`` that leaves unknown fields in place.

    Looking up a key that is not present returns the placeholder text
    ``"{key}"`` instead of raising ``KeyError``. Static strings may therefore
    contain placeholders for names that are not multi arguments. Only bare
    ``{name}`` placeholders survive unchanged. A conversion or format spec on
    an unknown name is applied to the placeholder text itself.
    """

    def __missing__(self, key):
        return "{" + key + "}"


def product_call(
    fct: Callable,
    static: Optional[Dict[str, Any]] = None,
    multi: Optional[Dict[str, Any]] = None,
    progress: bool = False,
) -> Tuple[List[Dict[str, Any]], List[Any]]:
    """Call ``fct`` once for every combination of the ``multi`` values.

    Kilian Lieret 2019
    https://gist.github.com/klieret/0f3b1b8a0b9b143702ae1aa662264029

    Example::

        product_call(fct, multi=dict(a=[1, 2, 3]), static=dict(b=4))

    calls::

        fct(a=1, b=4)
        fct(a=2, b=4)
        fct(a=3, b=4)

    With several ``multi`` entries, every combination is used. This is the
    Cartesian product in ``itertools.product`` order, so the last key varies
    fastest.

    String values in ``static`` are formatted with the multi arguments of the
    current call. For example, ``static=dict(b="{a}.pdf")`` results in::

        fct(a=1, b="1.pdf")
        fct(a=2, b="2.pdf")
        fct(a=3, b="3.pdf")

    Placeholders that do not name a multi argument are kept as-is
    (``"{x}"`` stays ``"{x}"``).

    Args:
        fct: The function to call.
        static: Keyword arguments passed to every call. They do not change
            between calls, apart from the string formatting described above.
            If a key also appears in ``multi``, the static value wins.
        multi: Mapping from keyword argument names to iterables of values.
            One call is made per combination. If empty or None, ``fct`` is
            called exactly once, with only the static arguments.
        progress: Show a tqdm progress bar over the calls (default False).

    Returns:
        Tuple ``(kwargs_list, results)`` in call order. ``kwargs_list`` holds
        the keyword-argument dict used for each call. ``results`` holds the
        corresponding return values.
    """
    static = {} if static is None else static
    multi = {} if multi is None else multi

    keys = list(multi)  # fixed key order, matched positionally below
    # Materialize the product so that tqdm can report the total count.
    combinations = list(itertools.product(*(multi[key] for key in keys)))
    if progress:
        # Wrap the combinations themselves, not the ``progress`` flag.
        combinations = tqdm(combinations)

    kwargs_list = []
    results = []
    for combination in combinations:
        multi_kwargs = dict(zip(keys, combination))
        # format_map does not mutate the mapping, so one instance per call suffices.
        fields = _KeepMissingFieldsDict(multi_kwargs)
        static_kwargs = {
            key: value.format_map(fields) if isinstance(value, str) else value
            for key, value in static.items()
        }
        # Static entries are unpacked last, so they override multi entries.
        kwargs = {**multi_kwargs, **static_kwargs}
        results.append(fct(**kwargs))
        kwargs_list.append(kwargs)
    return kwargs_list, results
def test_product_call():
    """Every (a, b) pair is called in product order, with "{a}{b}" formatted per call."""

    def add_a_and_b(a, b, c):
        # c is received only to check the formatting; it is not summed.
        return a + b

    call_kwargs, call_results = product_call(
        add_a_and_b,
        multi={"a": [1, 2, 3], "b": [4, 5, 6]},
        static={"c": "{a}{b}"},
    )

    expected_kwargs = [
        dict(a=1, b=4, c="14"),
        dict(a=1, b=5, c="15"),
        dict(a=1, b=6, c="16"),
        dict(a=2, b=4, c="24"),
        dict(a=2, b=5, c="25"),
        dict(a=2, b=6, c="26"),
        dict(a=3, b=4, c="34"),
        dict(a=3, b=5, c="35"),
        dict(a=3, b=6, c="36"),
    ]
    expected_results = [5, 6, 7, 6, 7, 8, 7, 8, 9]

    assert call_kwargs == expected_kwargs
    assert call_results == expected_results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment