# Source: GitHub Gist by @davidaknowles, created December 7, 2021.
# Gist ID: 1d45c56b40ffcf573cc6a5743c6e2f25
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from torch.distributions import constraints
# Six observations: the raw values below, each multiplied by 2.
data = torch.tensor([-1.0, -0.5, -0.5, 0.5, 0.8, 1.0]).mul(2.0)
def model(data):
    """Generative model: observations are a gene effect scaled by per-guide efficacies.

    ``guide_efficacy`` is a single vector-valued draw of ``len(data)``
    independent Beta(1, 1) variables. ``gene_essentiality`` is a scalar
    Normal(0, 5) draw. Each ``data[i]`` is modelled as
    Normal(gene_essentiality * guide_efficacy[i], 1).

    Args:
        data: 1-D tensor of observations, one per guide.
    """
    n = len(data)
    guide_efficacy = pyro.sample(
        "guide_efficacy", dist.Beta(1.0, 1.0).expand([n]).to_event(1)
    )
    gene_essentiality = pyro.sample("gene_essentiality", dist.Normal(0.0, 5.0))
    # Add a trailing dim so the scalar gene effect lines up with guide_efficacy's
    # event dim. When a guide enumerates a discrete site, gene_essentiality
    # carries extra leading dims (e.g. shape (2,)), and a plain product would
    # broadcast those against the wrong dims of guide_efficacy. Without
    # enumeration this is the same elementwise product as before.
    mean = gene_essentiality.unsqueeze(-1) * guide_efficacy
    # The observations are independent given the latents. Scoring them as one
    # length-n event (summed log-prob) keeps the batch shape equal to
    # guide_efficacy's, which is what enumeration requires.
    pyro.sample("obs", dist.Normal(mean, 1.0).to_event(1), obs=data)
def guide(data):
    """Two-component mixture guide over ``gene_essentiality`` and ``guide_efficacy``.

    A Bernoulli ``assignment`` site (probability ``prob``) selects one of two
    parameter sets. Conditional on that choice, ``gene_essentiality`` is Normal
    and ``guide_efficacy`` is a vector of ``len(data)`` independent Betas. The
    discrete site is intended to be enumerated, e.g. via
    ``config_enumerate(guide, "parallel")`` with ``TraceEnum_ELBO``.

    Args:
        data: 1-D tensor of observations; only its length is used.

    Returns:
        Tuple ``(assignment, gene_essentiality, guide_efficacy)`` of the sampled
        values, where ``assignment`` is a long tensor of 0/1 component indices.
    """
    n = len(data)
    prob = pyro.param("prob", torch.tensor(0.5), constraint=constraints.unit_interval)
    # .long() so the (possibly enumerated) 0/1 draw can index the 2-row params below.
    assignment = pyro.sample("assignment", dist.Bernoulli(prob)).long()

    # NOTE(review): both components are initialised identically (all ones) with
    # prob = 0.5, so they start out indistinguishable; consider distinct initial
    # values if the components are expected to specialise.
    ge_mean = pyro.param("ge_mean", torch.ones(2))
    ge_scale = pyro.param("ge_scale", torch.ones(2), constraint=constraints.positive)
    gene_essentiality = pyro.sample(
        "gene_essentiality",
        dist.Normal(ge_mean[assignment], ge_scale[assignment]),
    )

    efficacy_a = pyro.param(
        "guide_efficacy_a", torch.ones([2, n]), constraint=constraints.positive
    )
    efficacy_b = pyro.param(
        "guide_efficacy_b", torch.ones([2, n]), constraint=constraints.positive
    )
    # .to_event(1): the model samples guide_efficacy as one length-n event, so the
    # guide must use the same event shape. Otherwise the n Betas would be an
    # un-plated batch dim, which Pyro rejects.
    guide_efficacy = pyro.sample(
        "guide_efficacy",
        dist.Beta(efficacy_a[assignment, :], efficacy_b[assignment, :]).to_event(1),
    )
    return assignment, gene_essentiality, guide_efficacy
if __name__ == "__main__":
    # config_enumerate(..., "parallel") marks the guide's discrete "assignment"
    # site for enumeration, so TraceEnum_ELBO sums over its values instead of
    # sampling it. The continuous sites are still sampled, so the loss is a
    # Monte Carlo estimate.
    loss = TraceEnum_ELBO().loss(model, config_enumerate(guide, "parallel"), data)
    print(f"ELBO loss: {loss}")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.