Skip to content

Instantly share code, notes, and snippets.

@phinate
Created August 2, 2023 14:25
Show Gist options
  • Save phinate/1c95a8d216d26349f77035cd0da42d1c to your computer and use it in GitHub Desktop.
Save phinate/1c95a8d216d26349f77035cd0da42d1c to your computer and use it in GitHub Desktop.
pyhf-like models as pytrees
import jax.scipy as jsp
import equinox as eqx
import jax.numpy as jnp
from jax import Array
import jax
jax.config.update("jax_enable_x64", True)
@jax.jit
def poisson_logpdf(n, lam):
    """Return the elementwise Poisson log-density of ``n`` for rate ``lam``.

    The factorial term is evaluated as ``gammaln(n + 1)``, so ``n`` is not
    restricted to integer values.

    Args:
        n: Observed count(s).
        lam: Poisson rate(s), combined elementwise with ``n``.

    Returns:
        ``n * log(lam) - lam - gammaln(n + 1)``, where the first term is
        taken to be 0 whenever ``n == 0``.
    """
    # xlogy(0, 0) is 0, whereas the naive 0 * log(0) is 0 * -inf = nan.
    # An empty bin with zero expectation therefore contributes log(1) = 0
    # instead of poisoning the summed log-likelihood with nan.
    return jsp.special.xlogy(n, lam) - lam - jsp.special.gammaln(n + 1)
class Model(eqx.Module):
    """Common interface for the statistical models in this file.

    Both methods are placeholders that raise ``NotImplementedError``;
    concrete subclasses are expected to override them.
    """

    def expected_data(self, pars: dict[str, Array] | Array) -> Array:
        """Return the data expected for parameters ``pars``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def logpdf(self, data: Array, pars: dict[str, Array] | Array) -> Array:
        """Return the log-density of ``data`` for parameters ``pars``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class Systematic(eqx.Module):
    """A named systematic uncertainty together with its constraint model."""

    # Label of the systematic; HEPDataLike uses it as the key under which
    # the associated parameter is looked up in the parameter dict.
    name: str

    # Model supplying the constraint term (its logpdf / expected_data).
    constraint: Model
class PoissonConstraint(Model):
    """Poisson constraint term built from per-bin background uncertainties.

    The per-bin relative uncertainty ``binwise_uncerts / nominal_bkg`` is
    stored; the Poisson rate of the auxiliary measurement is then
    ``gamma * relative_uncertainty**-2``.
    """

    # Per-bin relative uncertainty: binwise_uncerts / nominal_bkg.
    scaled_binwise_uncerts: Array

    def __init__(self, nominal_bkg: Array, binwise_uncerts: Array) -> None:
        """Store the relative per-bin uncertainties.

        Args:
            nominal_bkg: Nominal background yields.
            binwise_uncerts: Uncertainties on ``nominal_bkg``; must have the
                same shape as ``nominal_bkg``.

        Raises:
            Whatever ``eqx.error_if`` raises when the two shapes differ.
        """
        # error_if hands back its first argument, and the check is only
        # guaranteed to stay attached when that returned value is used.
        nominal_bkg = eqx.error_if(
            nominal_bkg,
            nominal_bkg.shape != binwise_uncerts.shape,
            f"Nominal bkg shape {nominal_bkg.shape} does not match binwise uncertainty shape {binwise_uncerts.shape}",
        )
        self.scaled_binwise_uncerts = binwise_uncerts / nominal_bkg

    def expected_data(self, gamma: Array) -> Array:
        """Return the expected auxiliary data ``gamma * scaled_binwise_uncerts**-2``."""
        return gamma * self.scaled_binwise_uncerts**-2

    def logpdf(self, auxdata: Array, gamma: Array) -> Array:
        """Return the summed Poisson log-density of ``auxdata`` given ``gamma``.

        Args:
            auxdata: Auxiliary measurement(s) entering the constraint.
            gamma: Constrained parameter; must have the same shape as
                ``scaled_binwise_uncerts``.

        Returns:
            The sum over all elements of ``poisson_logpdf(auxdata, rate)``
            with ``rate = self.expected_data(gamma)``.

        Raises:
            Whatever ``eqx.error_if`` raises when the shape of ``gamma`` does
            not match.
        """
        # Route the checked value into the computation (see __init__).
        gamma = eqx.error_if(
            gamma,
            gamma.shape != self.scaled_binwise_uncerts.shape,
            f"Constrained param shape {gamma.shape} does not match number of bins {self.scaled_binwise_uncerts.shape}",
        )
        # Reuse expected_data so the rate formula lives in exactly one place.
        return jnp.sum(
            poisson_logpdf(auxdata, self.expected_data(gamma)),
            axis=None,
        )
class UncorrelatedShape(Systematic):
    """Systematic whose constraint is a ``PoissonConstraint``."""

    def __init__(self, name: str, nominal_bkg: Array, binwise_uncerts: Array) -> None:
        """Build the Poisson constraint and attach it under ``name``.

        Args:
            name: Label stored on the systematic.
            nominal_bkg: Nominal background yields, forwarded to
                ``PoissonConstraint``.
            binwise_uncerts: Uncertainties on ``nominal_bkg``, forwarded to
                ``PoissonConstraint``.
        """
        poisson_term = PoissonConstraint(nominal_bkg, binwise_uncerts)
        self.constraint = poisson_term
        self.name = name
class HEPDataLike(Model):
    """Signal-plus-background counting model with one uncorrelated shape systematic.

    The main expectation is ``mu * sig + gamma * bkg``, where ``mu`` and
    ``gamma`` are read from the parameter dict under ``poi_name`` and the
    systematic's name respectively. The systematic contributes a constraint
    term on ``gamma``.
    """

    # Signal yields, multiplied by the parameter of interest.
    sig: Array
    # Nominal background yields, multiplied by the nuisance parameter.
    bkg: Array
    # Background uncertainties, passed to UncorrelatedShape as binwise_uncerts.
    db: Array
    # Key of the parameter of interest in the parameter dict.
    poi_name: str
    # Shape systematic holding the nuisance parameter's name and constraint.
    systematic: UncorrelatedShape

    def __init__(
        self,
        sig: Array,
        bkg: Array,
        db: Array,
        poi_name: str = "mu",
        nuis_name: str = "shapesys",
    ) -> None:
        """Store the yields and build the shape systematic.

        Args:
            sig: Signal yields.
            bkg: Nominal background yields.
            db: Uncertainties on ``bkg``.
            poi_name: Key of the parameter of interest in the parameter dict.
            nuis_name: Key of the nuisance parameter in the parameter dict.
        """
        self.sig = sig
        self.bkg = bkg
        self.db = db
        self.poi_name = poi_name
        self.systematic = UncorrelatedShape(nuis_name, bkg, db)

    def expected_data(self, pars: dict[str, Array]) -> tuple[Array, Array]:
        """Return the expected main and auxiliary data for ``pars``.

        Args:
            pars: Parameter dict containing entries for ``poi_name`` and the
                systematic's name.

        Returns:
            A pair ``(main, aux)``: ``main`` is ``mu * sig + gamma * bkg`` and
            ``aux`` is the constraint's expected data for ``gamma``.
        """
        mu = pars[self.poi_name]
        gamma = pars[self.systematic.name]
        main = mu * self.sig + gamma * self.bkg
        aux = self.systematic.constraint.expected_data(gamma)
        return main, aux

    def logpdf(self, data: tuple[Array, Array], pars: dict[str, Array]) -> Array:
        """Return the total log-density of ``data`` for parameters ``pars``.

        Args:
            data: Pair ``(maindata, auxdata)``, laid out like the return value
                of ``expected_data``.
            pars: Parameter dict, as for ``expected_data``.

        Returns:
            The summed Poisson log-density of the main data plus the
            constraint log-density of the auxiliary data.
        """
        maindata, auxdata = data
        # Only the main expectation is needed here; the constraint derives
        # its own rate from gamma.
        expected_main, _ = self.expected_data(pars)
        main = jnp.sum(poisson_logpdf(maindata, expected_main), axis=None)
        constraint = self.systematic.constraint.logpdf(auxdata, pars[self.systematic.name])
        return main + constraint
# Example: build the model and generate the data it expects when every
# parameter is set to 1.
sig = jnp.array([5, 10])
bkg = jnp.array([50.0, 60.0])
uncerts = jnp.array([5.0, 12.0])
model = HEPDataLike(sig, bkg, uncerts, poi_name="mu", nuis_name="shapesys")
pars = {
    "mu": jnp.array(1.0),
    "shapesys": jnp.array([1.0, 1.0]),
}
data = model.expected_data(pars)

# Reference: the same yields and uncertainties expressed as a pyhf model.
import pyhf

pyhf_model = pyhf.simplemodels.uncorrelated_background([5, 10], [50, 60], [5, 12])
pyhf_pars = pyhf.tensorlib.astensor([1.0, 1.0, 1.0])
pyhf_data = pyhf_model.expected_data(pyhf_pars)

# The two implementations must give the same log-density at this point.
jax_logpdf = model.logpdf(data, pars)
pyhf_logpdf = pyhf.tensorlib.astensor(pyhf_model.logpdf(pyhf_pars, pyhf_data))
assert jnp.allclose(jax_logpdf, pyhf_logpdf)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment