Skip to content

Instantly share code, notes, and snippets.

@davipatti
Created January 19, 2023 21:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save davipatti/481c08887eb5ec0e9fef964b77422058 to your computer and use it in GitHub Desktop.
Save davipatti/481c08887eb5ec0e9fef964b77422058 to your computer and use it in GitHub Desktop.
[pymc parameter reporter] A small class that reports the values of a PyMC model's scalar parameters while the model samples. #pymc
0 0 a= 1.233 c= 0.036
0 100 a=-0.130 c= 0.014
0 200 a=-0.260 c=-1.752
0 300 a= 1.626 c=-1.839
0 400 a= 0.291 c= 0.438
0 500 a=-0.091 c=-0.126
0 600 a= 0.383 c= 0.046
0 700 a= 0.959 c= 0.280
0 800 a=-1.368 c= 1.264
0 900 a= 0.415 c=-0.102
0 1000 a=-1.474 c= 0.376
0 1100 a= 0.108 c=-0.429
0 1200 a= 1.258 c=-0.653
0 1300 a=-0.666 c=-0.250
0 1400 a=-0.875 c= 0.049
0 1500 a=-1.155 c= 0.452
0 1600 a= 1.747 c=-0.838
0 1700 a= 0.403 c=-0.868
0 1800 a= 0.817 c=-1.692
0 1900 a=-0.973 c= 1.650
import pymc as pm
class ParameterReporter:
    """Write scalar parameter values to one text file per chain during sampling.

    Use as a context manager around ``pm.sample`` and pass ``write`` as the
    ``callback``. Each reported line has the form::

        <chain:2d> <draw_idx:5d> name1=value1 name2=value2 ...

    Only 0-dimensional (scalar) variables are written. Values are formatted
    with ``: .3f`` and entries are sorted by their ``name=value`` string.
    """

    def __init__(
        self, n_chains: int, n: int = 100, fname_prefix: str = "chain"
    ) -> None:
        """
        Report parameters every n draws for n_chains.

        Args:
            n_chains: Should match number of chains in call to pm.sample.
                One file, ``f"{fname_prefix}{c}.txt"``, is opened for each
                chain index ``c`` in ``range(n_chains)``.
            n: Report parameters every n draws (draws whose index is a
                multiple of n). Must be a positive integer.
            fname_prefix: A file with this prefix is generated for each chain.

        Raises:
            ValueError: If ``n`` is smaller than 1. Without this check,
                ``n=0`` would only fail later, inside the sampling callback,
                with a ZeroDivisionError.
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")
        self.n_chains = n_chains
        self.n = n
        self.fname_prefix = fname_prefix
        # Populated by __enter__; empty until then so __exit__ is always safe.
        self.files: list = []

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(n_chains={self.n_chains}, n={self.n}, "
            f"fname_prefix={self.fname_prefix!r})"
        )

    def _close_files(self) -> None:
        """Close every file opened so far (closing twice is harmless)."""
        for fobj in self.files:
            fobj.close()

    def __enter__(self) -> "ParameterReporter":
        """Open (truncate) one output file per chain and return self.

        If opening any file fails, the files already opened are closed
        before the exception propagates, so no handles leak.
        """
        self.files = []
        try:
            for c in range(self.n_chains):
                self.files.append(
                    open(f"{self.fname_prefix}{c}.txt", "w", encoding="utf-8")
                )
        except BaseException:
            self._close_files()
            raise
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """Close all chain files. Exceptions from the with-body propagate."""
        self._close_files()

    def write(self, trace, draw) -> None:
        """
        Write parameters to a file.

        Implementation note:
            The pymc.sample callback passes 'trace' and 'draw' keyword
            arguments. ``trace`` is unused. From ``draw`` this reads
            ``draw_idx``, ``chain`` (used as an index into the per-chain
            files, so it must lie in ``range(n_chains)``) and ``point``
            (a mapping of variable name to value; only values with
            ``ndim == 0`` are written).
        """
        if draw.draw_idx % self.n != 0:
            return
        pts = sorted(f"{k}={v: .3f}" for k, v in draw.point.items() if v.ndim == 0)
        line = f"{draw.chain:2d} {draw.draw_idx:5d} {' '.join(pts)}\n"
        fobj = self.files[draw.chain]
        fobj.write(line)
        # Flush so the file reflects progress *while* sampling runs; otherwise
        # lines sit in the write buffer until it fills or the file is closed.
        fobj.flush()
# Simple usage
##############
with pm.Model() as model:
    pm.Normal("a", 0, 1)
    pm.Normal("b", 0, 1, shape=(10, 12))  # Non-scalar variables aren't reported
    pm.Normal("c", 0, 1)

pr = ParameterReporter(n_chains=4, n=100)

# Pass chains explicitly so it always matches pr.n_chains: the reporter opens
# one file per chain and looks them up by the chain index pymc reports.
with model, pr:
    pm.sample(chains=pr.n_chains, callback=pr.write, cores=4)

# Creates chain0.txt ... chain3.txt (one file per chain index 0..n_chains-1)
# Column 1 is the chain index
# Column 2 is the draw index
# Remaining columns are name=value pairs for the scalar variables
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment