Last active
April 27, 2021 12:50
-
-
Save syrte/cb4b51c2484b5d552dae29f0bc67d557 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ModelParam:
    """
    Parameter object that splits model parameters into fixed and free sets.

    Attributes
    ----------
    param_fixed : dict
        Parameters with fixed, known values.
    names_free : list of str
        Names of the free parameters, in the order in which
        `param_list_to_dict` expects their values.

    Methods
    -------
    param_list_to_dict(param_list)
        Convert a sequence of free-parameter values into a dict of all params.

    Usage
    -----
    >>> p0 = ModelParam({'a': 1, 'b': 2, 'c': 3})
    >>> p1 = p0.fix(c=2).free('a', 'b')
    >>> p1.names_free
    ['a', 'b']
    >>> p1.param_list_to_dict([3, 4])
    {'c': 2, 'a': 3, 'b': 4}
    """

    def __init__(self, param_fixed, names_free=None):
        """
        param_fixed : dict
            Parameter name -> value. The dict is copied, so the caller's
            object is never mutated.
        names_free : iterable of str, optional
            Names of free parameters (default: none). A name listed here is
            removed from `param_fixed`, i.e. "free" takes precedence over a
            fixed value for the same name.
        """
        # None instead of a shared mutable [] default.
        self.names_free = [] if names_free is None else [*names_free]
        self.param_fixed = {**param_fixed}
        for name in self.names_free:
            self.param_fixed.pop(name, None)

    def __repr__(self):
        return f'{type(self).__name__}({self.param_fixed!r}, {self.names_free!r})'

    def fix(self, **param_fixed):
        """
        Return a new ModelParam with the given params fixed to the given values.

        Any of these names that are currently free become fixed. Existing fixed
        values are overridden by the new ones.
        """
        names_free = [name for name in self.names_free if name not in param_fixed]
        return ModelParam({**self.param_fixed, **param_fixed}, names_free)

    def free(self, *names_free):
        """
        Return a new ModelParam with `names_free` appended to the free params.

        Their fixed values (if any) are dropped.

        Raises
        ------
        ValueError
            If a name is already free, or is given more than once in this call.
        """
        seen = set(self.names_free)
        duplicated = []
        for name in names_free:
            if name in seen:
                duplicated.append(name)
            seen.add(name)
        if duplicated:
            raise ValueError(f'duplicated params in names_free: {duplicated}')
        return ModelParam(self.param_fixed, [*self.names_free, *names_free])

    def param_list_to_dict(self, param_list):
        """
        Combine free-parameter values with the fixed params into one dict.

        param_list : sequence of values
            One value per entry of `names_free`, in the same order.

        Raises
        ------
        ValueError
            If the number of values differs from the number of free params
            (plain `zip` would silently truncate instead).
        """
        values = list(param_list)
        if len(values) != len(self.names_free):
            raise ValueError(
                f'expected {len(self.names_free)} free param values, '
                f'got {len(values)}'
            )
        param_free = dict(zip(self.names_free, values))
        return {**self.param_fixed, **param_free}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment