Skip to content

Instantly share code, notes, and snippets.

@syrte
Last active April 27, 2021 12:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syrte/cb4b51c2484b5d552dae29f0bc67d557 to your computer and use it in GitHub Desktop.
Save syrte/cb4b51c2484b5d552dae29f0bc67d557 to your computer and use it in GitHub Desktop.
class ModelParam:
    """
    Bookkeeping for a model's fixed and free parameters.

    Instances are treated as immutable: ``fix`` and ``free`` return a new
    ``ModelParam`` and leave the original untouched.

    Attributes:
        param_fixed: dict mapping the name of each fixed parameter to its
            known value.
        names_free: ordered list with the names of the free parameters.

    Usage:
        p0 = ModelParam({'a': 1, 'b': 2, 'c': 3})
        p1 = p0.fix(c=2).free('a', 'b')
        p1.names_free
        # ['a', 'b']
        p1.param_list_to_dict([3, 4])
        # {'a': 3, 'b': 4, 'c': 2}
    """

    def __init__(self, param_fixed, names_free=()):
        """
        Args:
            param_fixed: dict of parameter name -> fixed value.
            names_free: iterable of names of the free parameters.
                A name listed here is removed from ``param_fixed``,
                i.e. ``names_free`` overrides ``param_fixed``.
        """
        # Materialize once so a one-shot iterator is not exhausted by the
        # loop below before the names are stored.
        names_free = list(names_free)

        # Copy so later changes to the caller's dict do not leak in here.
        self.param_fixed = {**param_fixed}
        for name in names_free:
            self.param_fixed.pop(name, None)
        self.names_free = names_free

    def fix(self, **param_fixed):
        """
        Return a new ModelParam with the given parameters fixed.

        Args:
            **param_fixed: name=value pairs to fix. A name that is currently
                free stops being free; a name that is already fixed gets the
                new value.
        """
        names_free = [name for name in self.names_free
                      if name not in param_fixed]
        return ModelParam({**self.param_fixed, **param_fixed}, names_free)

    def free(self, *names_free):
        """
        Return a new ModelParam with the given parameters made free.

        The new names are appended after the names that are already free,
        in the order given.

        Raises:
            ValueError: if a name is already free, or is listed more than
                once in the call.
        """
        # Track names seen so far so that repeats inside this very call
        # (e.g. free('a', 'a')) are rejected too, not only clashes with the
        # names that were free beforehand.
        seen = set(self.names_free)
        for name in names_free:
            if name in seen:
                raise ValueError('duplicated params in names_free')
            seen.add(name)
        return ModelParam(self.param_fixed, [*self.names_free, *names_free])

    def param_list_to_dict(self, param_list):
        """
        Combine values for the free parameters with the fixed parameters.

        Args:
            param_list: iterable of values, one per entry of
                ``names_free`` and in the same order.

        Returns:
            dict with the values of all parameters, free and fixed.

        Raises:
            ValueError: if the number of values differs from the number of
                free parameters.
        """
        values = list(param_list)
        # zip() would silently truncate on a length mismatch, dropping free
        # parameters from the result or ignoring surplus values.
        if len(values) != len(self.names_free):
            raise ValueError(
                'expected %d values for free params, got %d'
                % (len(self.names_free), len(values)))
        param_free = dict(zip(self.names_free, values))
        return {**self.param_fixed, **param_free}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment