@paraschopra
Created November 26, 2018 07:07
import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, Categorical

def model(x_data, y_data):
    # Place standard normal priors over the weights and biases of each
    # layer of `net` (the plain PyTorch network defined elsewhere).
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    # Lift the module's parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # Sample a concrete network (which also samples its weights and biases)
    lifted_reg_model = lifted_module()
    # Class log-probabilities; labels are observed through a Categorical
    lhat = F.log_softmax(lifted_reg_model(x_data), dim=-1)
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
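For context, a minimal sketch of how a model like this might be trained with Pyro's stochastic variational inference. It assumes a matching guide function (not shown in this gist) that defines variational distributions for the same sites, plus a standard DataLoader; the names guide, train_loader, and num_iterations are illustrative, not from the original snippet.

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

# Assumed: `guide(x_data, y_data)` mirrors `model` with learnable
# variational parameters for each lifted weight and bias site.
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

num_iterations = 5  # illustrative
for epoch in range(num_iterations):
    total_loss = 0.0
    for x_batch, y_batch in train_loader:  # assumed DataLoader of (inputs, labels)
        # One SVI gradient step on the ELBO; inputs flattened to match fc1
        total_loss += svi.step(x_batch.view(x_batch.size(0), -1), y_batch)
    print(f"Epoch {epoch} loss: {total_loss / len(train_loader.dataset):.4f}")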