Last active
July 7, 2017 18:44
-
-
Save matthewfeickert/255330b8d56a5d4f0dc64328a6143d97 to your computer and use it in GitHub Desktop.
Edward Shape Problem
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
import edward as ed | |
# specific modules | |
from edward.models import Normal | |
def sample_model(model, n_samples):
    """Draw ``n_samples`` samples from ``model`` and return them evaluated.

    A new ``tf.Session`` is opened on every call, and all global variables
    are initialised inside it before the sample op is run.

    Args:
        model: Object exposing ``sample(sample_shape)``; in this script an
            Edward or TensorFlow ``Normal``.
        n_samples: Number of draws, forwarded to ``model.sample`` as
            ``[n_samples]``.

    Returns:
        The result of running the sample op in the session.
    """
    with tf.Session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        draw_op = model.sample([n_samples])
        return session.run(draw_op)
# Number of draws requested from the model.
N = 10

# The parameters are tf.Variables rather than constants because the goal is
# to fit them.
mean = tf.Variable(3.0)
std = tf.Variable(1.0)

# A single Normal parameterised directly by the two variables.
x = Normal(loc=mean, scale=std)
samples = sample_model(x, N)

print(
    "x is a {0} with shape {1}".format(
        type(x),
        x.get_shape(),
    )
)
print(
    "\nsamples is a {0} with shape {1}".format(
        type(samples),
        samples.shape,
    )
)

# The inference call below is left disabled: it fails because x and samples
# don't have the same shape.
# mle = ed.MAP({}, data={x: samples})
# Alternative: give the Normal length-N parameters by multiplying each
# variable with a vector of ones.
x = Normal(
    loc=mean*tf.ones(N),
    scale=std*tf.ones(N),
)
samples = sample_model(x, N)

print(
    "\nx is a {0} with shape {1}".format(
        type(x),
        x.get_shape(),
    )
)
print(
    "\nsamples is a {0} with shape {1}".format(
        type(samples),
        samples.shape,
    )
)

# Binding the full sample array still fails, as x and samples don't have the
# same shape.
# mle = ed.MAP({}, data={x: samples})

# Binding a single row works, but is hugely inefficient: only 1 row of the
# N x N sample tensor is used.
mle = ed.MAP(
    {},
    data={x: samples[0]},
)
mle.run()

# Read the fitted value of `mean` back through Edward's session.
sess = ed.get_session()
print(sess.run(mean))
# As ed.models.Normal inherits from tf.contrib.distributions.Normal, the
# results are the same with pure TF.
x = tf.contrib.distributions.Normal(loc=mean*tf.ones(N), scale=std*tf.ones(N))
samples = sample_model(x, N)

# Report batch_shape rather than event_shape: a univariate Normal has a
# scalar event_shape `()` whatever its parameters are, so event_shape hides
# the length-N parameterisation that the Edward print of x.get_shape() shows.
print("\nx is a {0} with shape {1}".format(type(x), x.batch_shape))
print("samples is a {0} with shape {1}".format(type(samples), samples.shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment