@khacanh
Last active February 11, 2019 23:20
Mixture density network with TensorFlow
# From http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/
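# A mixture density network (MDN): instead of predicting a single value of y
# given x, the network outputs the parameters (mixing weights pi, means mu,
# standard deviations sigma) of a mixture of KMIX Gaussians, and is trained
# by minimizing the negative log-likelihood of the data under that mixture.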
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import math
NHIDDEN = 24
STDEV = 0.5
KMIX = 24 # number of mixtures
NOUT = KMIX * 3 # pi, mu, stdev
NSAMPLE = 2400
y_data = np.float32(np.random.uniform(-10.5, 10.5, (1, NSAMPLE))).T
r_data = np.float32(np.random.normal(size=(NSAMPLE,1))) # random noise
x_data = np.float32(np.sin(0.75*y_data)*7.0+y_data*0.5+r_data*1.0)
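# Note the inversion: y is sampled uniformly and x is computed as a noisy
# sinusoid of y, so a single x can correspond to several plausible y values.
# This one-to-many mapping is what ordinary least-squares regression cannot
# fit and what the mixture density network is meant to capture.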
x = tf.placeholder(dtype=tf.float32, shape=[None,1], name="x")
y = tf.placeholder(dtype=tf.float32, shape=[None,1], name="y")
Wh = tf.Variable(tf.random_normal([1,NHIDDEN], stddev=STDEV, dtype=tf.float32))
bh = tf.Variable(tf.random_normal([1,NHIDDEN], stddev=STDEV, dtype=tf.float32))
Wo = tf.Variable(tf.random_normal([NHIDDEN,NOUT], stddev=STDEV, dtype=tf.float32))
bo = tf.Variable(tf.random_normal([1,NOUT], stddev=STDEV, dtype=tf.float32))
hidden_layer = tf.nn.tanh(tf.matmul(x, Wh) + bh)
output = tf.matmul(hidden_layer,Wo) + bo
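# A single tanh hidden layer maps each scalar x to NOUT = 3 * KMIX values,
# which get_mixture_coef below interprets as the mixture weights, standard
# deviations, and means of the KMIX Gaussian components.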
def get_mixture_coef(output):
    # Split the NOUT = 3 * KMIX outputs into the three parameter blocks.
    out_pi, out_sigma, out_mu = tf.split(output, num_or_size_splits=3, axis=1)
    # Numerically stable softmax over the mixture weights: subtracting the
    # row maximum before exponentiating avoids overflow, and the result is
    # normalized so each row of out_pi sums to 1.
    max_pi = tf.reduce_max(out_pi, 1, keep_dims=True)
    out_pi = tf.subtract(out_pi, max_pi)
    out_pi = tf.exp(out_pi)
    normalize_pi = tf.reciprocal(tf.reduce_sum(out_pi, 1, keep_dims=True))
    out_pi = tf.multiply(normalize_pi, out_pi)
    # Exponentiate so the standard deviations are strictly positive.
    out_sigma = tf.exp(out_sigma)
    return out_pi, out_sigma, out_mu
out_pi, out_sigma, out_mu = get_mixture_coef(output)
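# The model's conditional density is a Gaussian mixture:
#   p(y | x) = sum_k pi_k(x) * N(y | mu_k(x), sigma_k(x)^2)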
oneDivSqrtTwoPI = 1 / math.sqrt(2*math.pi)
def tf_normal(y, mu, sigma):
    # Gaussian density, evaluated elementwise for every mixture component:
    #   N(y | mu, sigma^2) = exp(-((y - mu) / sigma)^2 / 2) / (sigma * sqrt(2 * pi))
    result = tf.subtract(y, mu)
    result = tf.multiply(result, tf.reciprocal(sigma))
    result = -tf.square(result) / 2
    return tf.multiply(tf.exp(result), tf.reciprocal(sigma)) * oneDivSqrtTwoPI
def get_lossfunc(out_pi, out_sigma, out_mu, y):
    result = tf_normal(y, out_mu, out_sigma)           # (batch, KMIX) component densities
    result = tf.multiply(result, out_pi)               # weight each component by pi_k
    result = tf.reduce_sum(result, 1, keep_dims=True)  # (batch, 1): p(y | x) per sample
    result = -tf.log(result)                           # negative log-likelihood per sample
    return tf.reduce_mean(result)                      # average over the batch
lossfunc = get_lossfunc(out_pi, out_sigma, out_mu, y)
train_op = tf.train.AdamOptimizer().minimize(lossfunc)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
NEPOCH = 10000
loss = np.zeros(NEPOCH)
for i in range(NEPOCH):
    sess.run(train_op, feed_dict={x: x_data, y: y_data})
    loss[i] = sess.run(lossfunc, feed_dict={x: x_data, y: y_data})
# plt.figure(figsize=(8, 8))
# plt.plot(np.arange(100, NEPOCH,1), loss[100:], 'r-')
# plt.show()
x_test = np.float32(np.arange(-15,15,0.1))
NTEST = x_test.size
x_test = x_test.reshape(NTEST,1) # needs to be a matrix, not a vector
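# Sampling from the fitted mixture is a two-stage process: first pick a
# component k with probability pi_k (inverse-CDF sampling on the mixture
# weights), then draw from the chosen Gaussian N(mu_k, sigma_k^2).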
def get_pi_idx(x, pdf):
    # Inverse-CDF sampling: return the index of the first mixture component
    # whose cumulative weight reaches x.
    N = pdf.size
    accumulate = 0
    for i in range(0, N):
        accumulate += pdf[i]
        if accumulate >= x:
            return i
    print('error with sampling ensemble')
    return -1
def generate_ensemble(out_pi, out_mu, out_sigma, M=10):
    NTEST = x_test.size
    result = np.random.rand(NTEST, M)  # initially random [0, 1]
    rn = np.random.randn(NTEST, M)     # normal random matrix (0.0, 1.0)
    # transform result into M samples from the mixture at each test point
    for j in range(0, M):
        for i in range(0, NTEST):
            idx = get_pi_idx(result[i, j], out_pi[i])
            mu = out_mu[i, idx]
            std = out_sigma[i, idx]
            result[i, j] = mu + rn[i, j] * std
    return result
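# An optional vectorized alternative to the double loop above (a sketch, not
# part of the original gist; it assumes each row of out_pi sums to 1, which
# get_mixture_coef guarantees):
def generate_ensemble_vectorized(out_pi, out_mu, out_sigma, M=10):
    N = out_pi.shape[0]
    u = np.random.rand(N, M)                  # uniform draws for component choice
    cdf = np.cumsum(out_pi, axis=1)           # (N, KMIX) cumulative mixture weights
    # index of the first component whose cumulative weight reaches u
    idx = np.argmax(u[:, :, None] <= cdf[:, None, :], axis=2)
    mu = np.take_along_axis(out_mu, idx, axis=1)
    std = np.take_along_axis(out_sigma, idx, axis=1)
    return mu + np.random.randn(N, M) * std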
# Reuse the tensors built above rather than calling get_mixture_coef again,
# which would add a duplicate set of ops to the graph.
out_pi_test, out_sigma_test, out_mu_test = sess.run([out_pi, out_sigma, out_mu], feed_dict={x: x_test})
y_test = generate_ensemble(out_pi_test, out_mu_test, out_sigma_test)
plt.figure(figsize=(8, 8))
plt.plot(x_data, y_data, 'ro', x_test, y_test, 'bo', alpha=0.3)
plt.show()
ashimrijal commented Dec 12, 2018

Hi,
This is great that you implemented an MDN in TensorFlow. I have a question about line 59 (tf.reduce_mean(result)). I don't understand why we compute the mean of the tensor named "result" there: line 57 already summed all the elements of that tensor, and line 58 took the negative logarithm, so doesn't result have only one entry at that point (or am I missing something)?

Any help would be highly appreciated.

Thanks,
Ashim
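
The shapes resolve the puzzle: the sum on line 57 is tf.reduce_sum(result, 1, keep_dims=True), which sums only across the KMIX mixture components of each sample, not across the batch, so result is still a column with one negative log-likelihood per sample; tf.reduce_mean then averages it over the batch. A minimal trace of the shapes, assuming a feed of NSAMPLE points:

# tf_normal(y, out_mu, out_sigma)            -> (NSAMPLE, KMIX)
# tf.multiply(result, out_pi)                -> (NSAMPLE, KMIX)
# tf.reduce_sum(result, 1, keep_dims=True)   -> (NSAMPLE, 1)
# -tf.log(result)                            -> (NSAMPLE, 1)
# tf.reduce_mean(result)                     -> scalar: mean NLL over the batch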
