Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import sys
import math
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
def encoder(observation):
    """Map an observation tensor to a 32-unit linear embedding.

    The dense layer lives in the 'encoder' variable scope, which is opened
    with ``tf.AUTO_REUSE`` so that repeated calls reuse the same variables
    instead of creating new ones.

    Args:
        observation: Input tensor fed to the dense layer.

    Returns:
        The output tensor of the dense layer (no activation applied).
    """
    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        embedding = tf.layers.dense(observation, 32)
    return embedding
def decoder(z, len_sample):
    """Decode a latent tensor into a signal of ``len_sample`` values.

    Both dense layers are created in the 'decoder' variable scope, which is
    opened with ``tf.AUTO_REUSE`` so that repeated calls reuse the same
    variables.

    Args:
        z: Latent input tensor.
        len_sample: Number of output units of the final (linear) layer.

    Returns:
        The output tensor of the final dense layer.
    """
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        # ReLU hidden layer followed by a linear read-out.
        hidden = tf.layers.dense(z, 32, activation=tf.nn.relu)
        return tf.layers.dense(hidden, len_sample)
def get_posterior(x0, xt):
    """Build a diagonal Gaussian over the latent from two encoded inputs.

    The two inputs are concatenated along axis 1, passed through one ReLU
    hidden layer, and then through two separate heads that produce the mean
    and the (softplus-activated) diagonal scale of the distribution.

    Args:
        x0: First encoded input tensor.
        xt: Second encoded input tensor.

    Returns:
        A ``tf.contrib.distributions.MultivariateNormalDiag`` whose location
        and diagonal scale are the two 32-unit head outputs.
    """
    # Shared MLP trunk over the concatenated encodings.
    joint = tf.concat([x0, xt], axis=1)
    features = tf.layers.dense(joint, 32, activation=tf.nn.relu)

    # Distribution parameters; softplus is log(exp(x) + 1).
    loc = tf.layers.dense(features, 32)
    scale_diag = tf.layers.dense(features, 32, activation=tf.nn.softplus)

    return tf.contrib.distributions.MultivariateNormalDiag(loc, scale_diag)
def _make_trajectories(len_sample, time_steps, num_points=500, noise_std=0.02):
    """Return shuffled (history, target) window pairs cut from a noisy sine.

    For every index ``idx`` in ``range(len_sample, num_points - time_steps)``
    the history window is the ``len_sample`` values ending at ``idx`` and the
    target window is the ``len_sample`` values ending at ``idx + time_steps``.
    Fresh Gaussian noise is drawn on every call.
    """
    t = np.linspace(-5 * np.pi, 5 * np.pi, num_points)
    y = np.sin(t) + np.random.normal(0, noise_std, size=num_points)
    trajectories = [
        (
            y[(idx - len_sample):idx],
            y[(idx + time_steps - len_sample):(idx + time_steps)],
        )
        for idx in range(len_sample, num_points - time_steps)
    ]
    np.random.shuffle(trajectories)
    return trajectories


def _split_into_batches(items, batch_size):
    """Chunk ``items`` into consecutive batches; the last one may be shorter."""
    return [
        items[start:start + batch_size]
        for start in range(0, len(items), batch_size)
    ]


def main():
    """Build the model graph, train it on a noisy sine wave, and plot samples.

    The graph has two parts:

    * an inference network that encodes both the history window ``i0`` and
      the target window ``it``, forms a posterior over the latent, samples
      it and decodes the sample;
    * a prior network that feeds the encoded history to an LSTM for
      ``time_steps`` steps and turns the final output into a prior over the
      same latent, whose sample is decoded into a prediction.

    Training runs for 4000 epochs with Adam. Losses are printed every 50
    epochs, and every 1000 epochs three sampled predictions are plotted
    against the ground truth.

    Returns:
        None.
    """
    len_sample = 40
    # Number of LSTM steps; also the offset between history and target ends.
    time_steps = 60
    batch_size = 256
    num_epochs = 4000
    num_plot_samples = 3

    # ---- inference network ----
    i0 = tf.placeholder(tf.float32, shape=[None, len_sample])
    it = tf.placeholder(tf.float32, shape=[None, len_sample])

    x0 = encoder(i0)
    xt = encoder(it)
    posterior = get_posterior(x0, xt)
    z_inf = posterior.sample()
    x_inf = decoder(z_inf, len_sample)

    # ---- prior network ----
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # static_rnn takes a list with one [batch_size, input_size] tensor per
    # step; the input at every step is x0.
    rnn_inputs = [x0] * time_steps
    rnn_outputs, _ = tf.nn.static_rnn(lstm_cell, rnn_inputs, dtype=tf.float32)
    mu_prior = tf.layers.dense(rnn_outputs[-1], 32)
    sigma_prior = tf.layers.dense(
        rnn_outputs[-1], 32, activation=tf.nn.softplus)
    prior = tf.contrib.distributions.MultivariateNormalDiag(
        mu_prior, sigma_prior)
    sample_prior = prior.sample()
    predicted = decoder(sample_prior, len_sample)

    # ---- losses ----
    # tf.losses.mean_squared_error returns a scalar under its default
    # reduction, so it broadcasts over the per-example KL terms before the
    # mean is taken.
    loss_inf = tf.reduce_mean(
        tf.losses.mean_squared_error(it, x_inf)
        + posterior.kl_divergence(prior))
    loss_prior = tf.reduce_mean(prior.kl_divergence(posterior))
    loss = loss_inf + loss_prior
    opt = tf.train.AdamOptimizer(learning_rate=2e-3).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            # A freshly noised, reshuffled dataset is generated every epoch.
            trajectories = _make_trajectories(len_sample, time_steps)
            # Chunk with a stride so the trailing partial batch is kept
            # (the previous `while idx < len - batch_size` loop silently
            # dropped it, and produced no batch at all for small datasets).
            batches = _split_into_batches(trajectories, batch_size)

            for batch in batches:
                i0_batch = [item[0] for item in batch]
                it_batch = [item[1] for item in batch]
                l_val, l_inf, l_pr, _ = sess.run(
                    [loss, loss_inf, loss_prior, opt],
                    feed_dict={i0: i0_batch, it: it_batch},
                )

            if (epoch + 1) % 50 == 0:
                # Reports the losses of the last batch of this epoch.
                print('Epoch #', epoch, 'Loss: ', l_val,
                      'loss_inf:', l_inf, 'loss_prior:', l_pr)

            if (epoch + 1) % 1000 == 0:
                # Visualize the prediction for the first training pair.
                trajectory, ground_truth = batches[0][0]
                # Each run draws a new sample from the prior.
                predicted_trajectories = [
                    sess.run(predicted, feed_dict={i0: [trajectory]})[0]
                    for _ in range(num_plot_samples)
                ]

                t1 = np.arange(0, len_sample)
                t2 = np.arange(time_steps, time_steps + len_sample)
                plt.plot(t1, trajectory)
                plt.plot(t2, ground_truth)
                for sampled in predicted_trajectories:
                    plt.plot(t2, sampled)
                plt.xlabel('time')
                plt.ylabel('state')
                plt.axis('tight')
                plt.show()

    print('Exiting...')
# Script entry point: run main() and hand its return value to sys.exit.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment