Skip to content

Instantly share code, notes, and snippets.

@tscholak
Last active April 15, 2017 14:37
Show Gist options
  • Save tscholak/14a042a65eb4e390396c9dab5aefc9f7 to your computer and use it in GitHub Desktop.
Save tscholak/14a042a65eb4e390396c9dab5aefc9f7 to your computer and use it in GitHub Desktop.
`TypeError` when initializing HMC for `ParamMixture` model
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from time import time
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from edward.models import (Dirichlet, Empirical, InverseGamma, ParamMixture, Normal)
# Generate synthetic data from a known K-component Gaussian mixture.
true_mu = np.array([-1., 0., 1.], np.float32) * 10  # component means
true_sigmasq = np.array([1.**2, 2.**2, 3.**2], np.float32)  # component variances
true_pi = np.array([0.2, 0.3, 0.5], np.float32)  # mixing weights
N = 10000
K = len(true_mu)
# Component index of every observation. np.random.choice returns the platform
# default integer type (int64 on most systems); cast to int32 so the observed
# assignments match the default dtype of TensorFlow's Categorical.
true_c = np.random.choice(np.arange(K), size=N, p=true_pi).astype(np.int32)
# np.random.randn yields float64, which would silently upcast x_val; cast back
# to float32 so the observations agree with the float32 parameters above.
x_val = (true_mu[true_c]
         + np.random.randn(N) * np.sqrt(true_sigmasq[true_c])).astype(np.float32)

# Prior hyperparameters
pi_alpha = np.ones(K, dtype=np.float32)  # Dirichlet concentration of 1 (uniform on the simplex)
# Prior scale of the component means, derived from the spread of the true means.
mu_sigma = np.std(true_mu)
sigmasq_alpha = 1.
sigmasq_beta = 2.
# Model: Gaussian mixture with K components and priors on the mixing weights,
# the component means and the component variances.
pi = Dirichlet(pi_alpha)  # mixing weights
mu = Normal(0., mu_sigma, sample_shape=[K])  # one mean per component
sigmasq = InverseGamma(sigmasq_alpha, sigmasq_beta,
                       sample_shape=[K])  # one variance per component
# Keyword arguments forwarded to each Normal component; the standard deviation
# is the square root of the sampled variance.
# NOTE(review): 'mu'/'sigma' are Normal's keyword names in TensorFlow <= 1.0;
# TF >= 1.1 renamed them to 'loc'/'scale' -- confirm against the installed
# Edward/TensorFlow versions.
component_params = {'mu': mu, 'sigma': tf.sqrt(sigmasq)}
x = ParamMixture(pi, component_params, Normal, sample_shape=N)
# The mixture's categorical indicator; it is observed alongside x at inference.
c = x.cat
# Inference: HMC over (pi, mu, sigmasq), with each posterior represented by an
# Empirical distribution holding a [T, K] table of samples.
T = 5000  # number of samples
qpi = Empirical(params=tf.Variable(tf.ones([T, K]) / K))  # rows initialised to 1/K
qmu = Empirical(params=tf.Variable(tf.zeros([T, K])))
qsigmasq = Empirical(params=tf.Variable(tf.ones([T, K])))
# Map each prior random variable to its Empirical approximation.
latent_vars = {pi: qpi, mu: qmu, sigmasq: qsigmasq}
# Condition on both the observations and their true component assignments.
observed = {x: x_val, c: true_c}
# NOTE(review): pi lives on the simplex and sigmasq must be positive; confirm
# the installed Edward's HMC handles these constrained latents (later Edward
# releases transform them to an unconstrained space automatically).
inference = ed.HMC(latent_vars, data=observed)
inference.initialize(n_print=10, step_size=0.6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment