Skip to content

Instantly share code, notes, and snippets.

@demodw
Last active December 14, 2017 22:58
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save demodw/a5acafb4025d53f4d3f80f559dcce775 to your computer and use it in GitHub Desktop.
Save demodw/a5acafb4025d53f4d3f80f559dcce775 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from edward.models import Normal
from edward.stats import Multinomial, norm
class HierarchicalSoftmax(object):
    """Bayesian multinomial-logit (softmax) regression model.

    Written against Edward's model-wrapper interface: ``log_prob(xs, zs)``
    returns the log joint density of the data and the latent variables.
    Category 0 is the reference category whose logit is fixed at 0;
    categories 1..C-1 each have one weight row in ``w`` and one intercept
    in ``b``.  Every entry of ``w`` and ``b`` gets an independent
    Normal(0, ``prior_std``) prior.

    Args:
        inv_link: Callable mapping an (N, C) logit tensor to (N, C) class
            probabilities.  Defaults to ``tf.nn.softmax``.
        prior_std: Standard deviation of the zero-mean Normal prior placed
            on ``w`` and ``b``.
    """

    def __init__(self, inv_link=tf.nn.softmax, prior_std=3.0):
        self.inv_link = inv_link
        self.prior_std = prior_std

    def log_prob(self, xs, zs):
        """Return the scalar log joint density log p(x, y, w, b).

        Args:
            xs: Data dict.  ``xs['x']`` is the (N, D) feature matrix and
                ``xs['y']`` the (N, C) matrix of per-category counts, one
                row per data point.
            zs: Latent-variable dict.  ``zs['w']`` is the (C-1, D) weight
                matrix and ``zs['b']`` the (C-1,) intercept vector.

        Returns:
            Scalar tensor: the multinomial log-likelihood summed over all
            N rows plus the Normal log prior of ``w`` and ``b``.
        """
        w, b = zs['w'], zs['b']
        # The data may arrive as float64/int64 arrays; cast to the latent
        # dtype so the matmul and the arithmetic below type-check.
        x = tf.cast(xs['x'], w.dtype)
        y = tf.cast(xs['y'], w.dtype)

        # Independent Normal(0, prior_std) prior on every weight/intercept.
        log_prior = (tf.reduce_sum(norm.logpdf(w, 0.0, self.prior_std)) +
                     tf.reduce_sum(norm.logpdf(b, 0.0, self.prior_std)))

        # Logits for categories 1..C-1: (N, D) x (D, C-1) + (C-1,).
        logits = tf.matmul(x, w, transpose_b=True) + b
        # Prepend a zero column: the fixed logit of reference category 0.
        logits = tf.pad(logits, [[0, 0], [1, 0]])

        if self.inv_link is tf.nn.softmax:
            # Numerically stable log(softmax(.)) for the default link.
            log_p = tf.nn.log_softmax(logits)
        else:
            log_p = tf.log(self.inv_link(logits))

        # Multinomial log pmf of each row:
        #   lgamma(n + 1) - sum_k lgamma(y_k + 1) + sum_k y_k * log p_k
        # where n = sum_k y_k is that row's number of trials.
        n_trials = tf.reduce_sum(y, 1)
        log_lik = tf.reduce_sum(
            tf.lgamma(n_trials + 1.0)
            - tf.reduce_sum(tf.lgamma(y + 1.0), 1)
            + tf.reduce_sum(y * log_p, 1))
        return log_lik + log_prior
def build_toy_dataset(N, D, C, n_trials=100):
    """Generate a synthetic multinomial-regression dataset.

    The features are ``N * D`` evenly spaced points on [-3, 3], transformed
    by ``(x - 4) / 4`` and laid out as an (N, D) matrix.  Each of the N
    count rows is drawn from a Multinomial(``n_trials``, uniform over ``C``
    categories), independently of ``x``.

    Args:
        N: Number of data points (rows of ``x`` and ``y``).
        D: Number of features (columns of ``x``).
        C: Number of outcome categories (columns of ``y``).
        n_trials: Number of multinomial trials per row; every row of ``y``
            sums to this value.

    Returns:
        ``(x, y)``: a float array of shape (N, D) and an integer count
        array of shape (N, C).
    """
    x = np.linspace(-3, 3, num=N * D)
    x = ((x - 4.0) / 4.0).reshape((N, D))
    y = np.random.multinomial(n=n_trials, pvals=[1.0 / C] * C, size=N)
    return x, y
# Uncomment for reproducible runs:
# ed.set_seed(42)
C = 3  # num outcome categories
N = 40  # num data points
# num features (columns of x); each category's weight row has D entries.
D = 5
x_train, y_train = build_toy_dataset(N, D, C)

model = HierarchicalSoftmax()

# Normal variational factors: w is (C-1, D) with one softplus-positive scale
# per feature column (presumably broadcast over rows), b is (C-1,) with a
# single shared scale.
qw_mu = tf.Variable(tf.random_normal([C-1, D]))
qw_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([D])))
qb_mu = tf.Variable(tf.random_normal([C-1]))
qb_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([])))

qw = Normal(mu=qw_mu, sigma=qw_sigma)
qb = Normal(mu=qb_mu, sigma=qb_sigma)

data = {'x': x_train, 'y': y_train}
inference = ed.MFVI({'w': qw, 'b': qb}, data, model)
# run() calls initialize() with these keyword arguments and then performs the
# optimisation updates; initialize() on its own only builds the graph and
# never fits qw/qb.
inference.run(n_print=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment