Created
January 19, 2023 21:18
-
-
Save davipatti/481c08887eb5ec0e9fef964b77422058 to your computer and use it in GitHub Desktop.
[pymc parameter reporter] A small class for reporting the status of a PyMC model while it samples. #pymc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0 0 a= 1.233 c= 0.036 | |
0 100 a=-0.130 c= 0.014 | |
0 200 a=-0.260 c=-1.752 | |
0 300 a= 1.626 c=-1.839 | |
0 400 a= 0.291 c= 0.438 | |
0 500 a=-0.091 c=-0.126 | |
0 600 a= 0.383 c= 0.046 | |
0 700 a= 0.959 c= 0.280 | |
0 800 a=-1.368 c= 1.264 | |
0 900 a= 0.415 c=-0.102 | |
0 1000 a=-1.474 c= 0.376 | |
0 1100 a= 0.108 c=-0.429 | |
0 1200 a= 1.258 c=-0.653 | |
0 1300 a=-0.666 c=-0.250 | |
0 1400 a=-0.875 c= 0.049 | |
0 1500 a=-1.155 c= 0.452 | |
0 1600 a= 1.747 c=-0.838 | |
0 1700 a= 0.403 c=-0.868 | |
0 1800 a= 0.817 c=-1.692 | |
0 1900 a=-0.973 c= 1.650 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymc as pm | |
class ParameterReporter:
    """Write the scalar values of each chain to a text file during sampling.

    Use an instance as a context manager and pass its ``write`` method as the
    ``callback`` argument of ``pm.sample``. One file per chain is opened on
    entry and closed on exit.
    """

    def __init__(
        self, n_chains: int, n: int = 100, fname_prefix: str = "chain"
    ) -> None:
        """
        Report parameters every n draws for n_chains.

        Args:
            n_chains: Should match number of chains in call to pm.sample.
            n: Report parameters every n draws.
            fname_prefix: A file with this prefix is generated for each chain.
        """
        self.n_chains = n_chains
        self.n = n
        self.fname_prefix = fname_prefix
        # One open file object per chain while inside the `with` block;
        # empty whenever no files are open.
        self.files = []

    def __repr__(self) -> str:
        return f"ParameterReporter(n_chains={self.n_chains}, n={self.n}, fname_prefix={self.fname_prefix})"

    def __enter__(self) -> "ParameterReporter":
        """
        Open (and truncate) one '<fname_prefix><chain>.txt' file per chain.

        Returns:
            This reporter, so that ``with ParameterReporter(...) as pr:`` works.

        Raises:
            OSError: If any file cannot be opened. Files that were already
                opened are closed before the error propagates.
        """
        self.files = []
        try:
            for c in range(self.n_chains):
                self.files.append(open(f"{self.fname_prefix}{c}.txt", "w"))
        except OSError:
            # Don't leak the handles that were opened before the failure.
            self._close_files()
            raise
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """Close every chain file. Exceptions from the `with` body propagate."""
        self._close_files()

    def _close_files(self) -> None:
        """Close all currently open chain files and forget them."""
        for fobj in self.files:
            fobj.close()
        self.files = []

    def write(self, trace, draw) -> None:
        """
        Write parameters to a file.

        Only every n-th draw is written, and only values with ``ndim == 0``
        are included. Each line holds the chain index, the draw index and the
        sorted ``name=value`` pairs.

        Implementation note:
            The pymc.sample callback passes 'trace' and 'draw' keyword arguments.

        Args:
            trace: Unused; accepted because the callback receives it.
            draw: Object exposing ``chain``, ``draw_idx`` and ``point`` (a
                mapping of variable name to value).
        """
        if draw.draw_idx % self.n == 0:
            pts = sorted(f"{k}={v: .3f}" for k, v in draw.point.items() if v.ndim == 0)
            line = f"{draw.chain:2d} {draw.draw_idx:5d} {' '.join(pts)}\n"
            fobj = self.files[draw.chain]
            fobj.write(line)
            # Flush so the report can be read while sampling is still running.
            fobj.flush()
# Simple usage
##############

with pm.Model() as model:
    pm.Normal("a", 0, 1)
    pm.Normal("b", 0, 1, shape=(10, 12))  # Multivariate variables aren't reported
    pm.Normal("c", 0, 1)

pr = ParameterReporter(n_chains=4, n=100)

with model, pr:
    # Request exactly as many chains as the reporter opened files for, rather
    # than relying on pm.sample's default chain count.
    pm.sample(callback=pr.write, chains=pr.n_chains, cores=4)

# Creates chain{0-3}.txt (one file per chain)
# Column 1 is the chain index
# Column 2 is the draw index
# Remaining columns are values of variables
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment