@charmoniumQ
Created September 26, 2023 19:24
Partially pooled linear regression when x[0] == 0 for all instances where x[1] == "specific-class"
import pathlib
import datetime
import pymc
import arviz
import runner
random_seed = 0
cache = pathlib.Path(".cache")
cache.mkdir(exist_ok=True)
pathlib.Path("output").mkdir(exist_ok=True)
# Benchmark results with categorical "workload" and "collector" columns
# and numeric "n_ops" and "walltime" columns.
df = runner.get_results()
with pymc.Model(coords={
    "data": df.index,
    "workload": df.workload.cat.categories,
    "collector": df.collector.cat.categories,
}) as model:
    workload_idx = pymc.Data("workload_idx", df.workload.cat.codes, dims="data", mutable=False)
    collector_idx = pymc.Data("collector_idx", df.collector.cat.codes, dims="data", mutable=False)

    # Partially pooled baseline runtime: each workload's runtime is drawn
    # around a shared (pooled) mean with a learned spread.
    pooled_workload_runtime = pymc.Exponential("pooled_workload_runtime", 1/50)
    pooled_workload_runtime_stddev = pymc.Exponential("pooled_workload_runtime_stddev", 1/1)
    workload_runtime = pymc.Normal(
        "workload_runtime",
        mu=pooled_workload_runtime,
        sigma=pooled_workload_runtime_stddev,
        dims="workload",
    )

    # Partially pooled per-operation overhead for each collector.
    pooled_collector_overhead_per_op = pymc.Exponential("pooled_collector_overhead_per_op", 1/1e-4)
    pooled_collector_overhead_per_op_stddev = pymc.Exponential("pooled_collector_overhead_per_op_stddev", 1/1e-5)
    collector_overhead_per_op = pymc.Normal(
        "collector_overhead_per_op",
        mu=pooled_collector_overhead_per_op,
        sigma=pooled_collector_overhead_per_op_stddev,
        dims="collector",
    )

    # Linear model: runtime = workload baseline + n_ops * per-op collector overhead.
    est_runtime = pymc.Deterministic(
        "est_runtime",
        workload_runtime[workload_idx] + df.n_ops.values * collector_overhead_per_op[collector_idx],
        dims="data",
    )
    runtime_std = pymc.Exponential("runtime_std", 1/1e-1)
    runtime = pymc.Normal("runtime", mu=est_runtime, sigma=runtime_std, observed=df.walltime, dims="data")
graph = pymc.model_to_graphviz(model)
graph.render(outfile="output/model.png")
cache_file = cache / "prior.hdf5"
if cache_file.exists():
idata = arviz.from_netcdf(cache_file)
else:
with model:
idata = pymc.sample_prior_predictive(
random_seed=random_seed,
)
axes = arviz.plot_ppc(idata, var_names="runtime", group="prior")
axes.figure.savefig("output/prior_runtime.png")
cache_file = cache / "trace.hdf5"
if cache_file.exists():
idata = arviz.from_netcdf(cache_file)
else:
with model:
idata = pymc.sample(
random_seed=random_seed,
progressbar=True,
)
arviz.to_netcdf(idata, cache_file)
# Check convergence diagnostics: every R-hat should be close to 1.
rhat = arviz.rhat(idata)
assert all(float(rhat[var].max()) < 1.03 for var in rhat.data_vars)
axes = arviz.plot_trace(idata)
axes.ravel()[0].figure.savefig("output/trace.png")
axes = arviz.plot_posterior(idata, var_names=["pooled_workload_runtime", "pooled_collector_overhead_per_op", "runtime_std"])
axes.ravel()[0].figure.savefig("output/global_posteriors.png")
axes = arviz.plot_forest(idata, var_names="workload_runtime")
axes.ravel()[0].figure.savefig("output/workload_posteriors.png")
axes = arviz.plot_forest(idata, var_names="collector_overhead_per_op")
axes.ravel()[0].figure.savefig("output/collector_posteriors.png")