import sys

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
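
# A variational model for predicting future windows of a noisy sine wave:
# an inference network q(z | i0, it) encodes a pair of 40-step observation
# windows, an LSTM prior p(z | i0) predicts the latent from the first window
# alone, and a shared decoder maps latents back to windows (see main()).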


def encoder(observation):
    # embed an observation window into a 32-dim feature vector
    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        return tf.layers.dense(observation, 32)


def decoder(z, len_sample):
    # map a latent sample z back to an observation window of length len_sample
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, 32, tf.nn.relu)
        x = tf.layers.dense(x, len_sample)
        return x


def get_posterior(x0, xt):
    # MLP over the concatenated embeddings of the two observation windows
    x_input = tf.concat([x0, xt], axis=1)
    hidden = tf.layers.dense(x_input, 32, tf.nn.relu)
    # parameters of the diagonal Gaussian posterior q(z | x0, xt)
    mu = tf.layers.dense(hidden, 32)
    sigma = tf.layers.dense(hidden, 32, activation=tf.nn.softplus)  # softplus(x) = log(1 + exp(x)) keeps sigma positive
    return tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)
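

# Note: sampling from MultivariateNormalDiag is reparameterized, so gradients
# flow from the reconstruction loss back through z into the posterior network.
# For a diagonal Gaussian the draw is equivalent to this sketch ('mu' and
# 'sigma' as above):
#
#     eps = tf.random_normal(tf.shape(mu))
#     z = mu + sigma * eps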


def main():
    # inference network
    # inputs: two windows of the signal, len_sample steps each
    len_sample = 40
    i0 = tf.placeholder(tf.float32, shape=[None, len_sample])
    it = tf.placeholder(tf.float32, shape=[None, len_sample])
    # encoding (weights shared between the two calls via AUTO_REUSE)
    x0 = encoder(i0)
    xt = encoder(it)
    posterior = get_posterior(x0, xt)
    z_inf = posterior.sample()
    # decode the posterior sample back into an observation window
    x_inf = decoder(z_inf, len_sample)
    # prior network: an LSTM conditioned only on the starting window x0
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # static_rnn takes a length-time_steps list of [batch_size, input_size] tensors
    time_steps = 60
    # the input at every step is the same embedding x0
    rnn_inputs = [x0] * time_steps
    rnn_outputs, _ = tf.nn.static_rnn(lstm_cell, rnn_inputs, dtype=tf.float32)
    mu_prior = tf.layers.dense(rnn_outputs[-1], 32)
    sigma_prior = tf.layers.dense(rnn_outputs[-1], 32, activation=tf.nn.softplus)
    prior = tf.contrib.distributions.MultivariateNormalDiag(mu_prior, sigma_prior)
    sample_prior = prior.sample()
    predicted = decoder(sample_prior, len_sample)
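
    # At test time only i0 is fed: the prior produces p(z | x0) and each call
    # to prior.sample() yields a different latent, so decoding several samples
    # gives several plausible futures (used in the visualization below).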
    # inference network loss: reconstruction error plus KL(q(z|x0,xt) || p(z|x0))
    loss_inf = tf.reduce_mean(tf.losses.mean_squared_error(it, x_inf) + posterior.kl_divergence(prior))
    # prior network loss: pull the prior towards the posterior, KL(p || q)
    loss_prior = tf.reduce_mean(prior.kl_divergence(posterior))
    loss = loss_inf + loss_prior
    opt = tf.train.AdamOptimizer(learning_rate=2e-3).minimize(loss)
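
    # The combined objective resembles a conditional VAE bound with a symmetric
    # KL term:
    #
    #     loss = E_q[MSE(it, decoder(z))] + KL(q || p) + KL(p || q)
    #
    # No stop_gradient is applied, so both KL terms update both networks.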

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(4000):
            # generate the dataset: a fresh noisy sine wave every epoch
            t = np.linspace(-5 * np.pi, 5 * np.pi, 500)
            y = np.sin(t)
            noise = np.random.normal(0, 0.02, size=500)
            y = y + noise
            num_y = len(y)
            # sliding windows: it_data is the window time_steps ahead of i0_data
            trajectories = []
            for idx in range(len_sample, num_y - time_steps):
                i0_data = y[(idx - len_sample):idx]
                it_data = y[(idx + time_steps - len_sample):(idx + time_steps)]
                trajectories.append((i0_data, it_data))
            np.random.shuffle(trajectories)
            # split into minibatches (a trailing partial batch is dropped)
            batch_size = 256
            batch_idx = 0
            batches = []
            while batch_idx < len(trajectories) - batch_size:
                minibatch = trajectories[batch_idx:batch_idx + batch_size]
                batches.append(minibatch)
                batch_idx += batch_size
            for batch_i, batch in enumerate(batches):
                i0_batch = [item[0] for item in batch]
                it_batch = [item[1] for item in batch]
                l_val, l_inf, l_pr, _ = sess.run(
                    [loss, loss_inf, loss_prior, opt],
                    feed_dict={i0: i0_batch, it: it_batch},
                )
            if (epoch + 1) % 50 == 0:
                print('Epoch #', epoch, 'Loss:', l_val, 'loss_inf:', l_inf, 'loss_prior:', l_pr)
            if (epoch + 1) % 1000 == 0:
                # visualize the prediction: decode three samples from the prior
                trajectory, ground_truth = batches[0][0]
                predicted_trajectories = []
                for i in range(3):
                    p_sample = sess.run(predicted, feed_dict={i0: [trajectory]})
                    predicted_trajectories.append(p_sample[0])
                t1 = np.arange(0, len_sample)
                t2 = np.arange(time_steps, time_steps + len_sample)
                plt.plot(t1, trajectory)
                plt.plot(t2, ground_truth)
                for i in range(3):
                    plt.plot(t2, predicted_trajectories[i])
                plt.xlabel('time')
                plt.ylabel('state')
                plt.axis('tight')
                plt.show()

    print('Exiting...')


if __name__ == '__main__':
    sys.exit(main())