Skip to content

Instantly share code, notes, and snippets.

@unaoya
Last active January 5, 2018 05:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unaoya/53d988d72b7ce1ce964dba87f1d066c1 to your computer and use it in GitHub Desktop.
Save unaoya/53d988d72b7ce1ce964dba87f1d066c1 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, Beta, PointMass
ed.set_seed(42)

# DATA
# Four coin flips (one failure, three successes), stored as an (n, 1) column
# so it lines up with the (n, 1) shape of the likelihood defined below.
n = 4
x_train = np.array([0, 1, 1, 1], dtype=np.int32).reshape((n, 1))

# MODEL
# Beta(alpha, beta) prior on the success probability p; alpha = beta = 1.0
# is the uniform prior on [0, 1].
alpha = 1.0
beta = 1.0
# The prior hyperparameters are fixed numbers, so build them as constants.
# Wrapping them in tf.Variable makes them trainable, and Edward's default
# var_list includes every trainable variable the latent variables depend on,
# so inference would silently fit the prior itself instead of keeping it fixed.
p = Beta(concentration1=tf.constant([alpha]),
         concentration0=tf.constant([beta]))
# n i.i.d. Bernoulli(p) draws; p has shape [1], so x has shape (n, 1).
x = Bernoulli(probs=p, sample_shape=(n,))
# INFERENCE by MAP
# p must stay inside (0, 1). Optimizing an unconstrained PointMass lets a
# gradient step leave that interval, which makes log p(x, p) NaN. Instead,
# optimize an unconstrained logit and squash it through a sigmoid. The point
# estimate still starts at 0.1. Sigmoid is a one-to-one map onto (0, 1), so
# the mode being searched for is the same.
p_init = 0.1
qp = PointMass(
    params=tf.nn.sigmoid(
        tf.Variable([np.log(p_init / (1.0 - p_init))], dtype=tf.float32)))
inference = ed.MAP({p: qp}, data={x: x_train})
inference.run()

# CRITICISM
sess = ed.get_session()
print(sess.run(qp.mean()))
print(sess.run(qp.params))  # the point estimate, in probability space
print(qp.eval())
# Plug-in predictive: the likelihood x with p replaced by the MAP estimate.
x_post = ed.copy(x, {p: qp})
print(ed.evaluate('binary_accuracy', data={x_post: x_train}))
# INFERENCE by KLqp
# Variational family q(p) = Beta(a, b). Both concentrations must stay
# positive, so optimize unconstrained variables and map them through
# softplus. softplus(log(e - 1)) == 1.0, so q starts at Beta(1, 1).
q_init = np.log(np.expm1(1.0))
qp = Beta(
    concentration1=tf.nn.softplus(tf.Variable([q_init], dtype=tf.float32)),
    concentration0=tf.nn.softplus(tf.Variable([q_init], dtype=tf.float32)))
inference = ed.KLqp({p: qp}, data={x: x_train})
inference.run()

# CRITICISM
sess = ed.get_session()
print(sess.run(qp.mean()))
# Printed as (concentration0, concentration1), i.e. (b, a).
print(sess.run(qp.concentration0), sess.run(qp.concentration1))
print(qp.eval())  # evaluates qp's value tensor: a random draw from q(p)
# Posterior predictive: the likelihood x with p replaced by q(p).
x_post = ed.copy(x, {p: qp})
print(ed.evaluate('binary_accuracy', data={x_post: x_train}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment