Skip to content

Instantly share code, notes, and snippets.

@lemonzi
Created April 24, 2017 16:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lemonzi/0b1be3ba538a8ce3c0501114f9c4ec7e to your computer and use it in GitHub Desktop.
Unit testing the A3C algorithm
import sys

import numpy as np
import scipy.signal
import tensorflow as tf
from tensorflow.contrib import distributions

# NOTE(review): discount() below is decorated with tensorify
# (https://github.com/lemonzi/tensorify) — that third-party package must
# also be importable for this script to run.
# Input features: one random tensor of shape (batch=128, time=1000, obs=1).
# Drop the leading batch dimension if your code does not work with batches.
features = tf.random_normal([128, 1000, 1])

# Put your networks here
# policy_mu, policy_sigma = ...

# Draw actions from the policy's Normal distribution. Wrapping the sample
# in stop_gradient() is essential: gradients must not flow through the
# sampling op when the loss is backpropagated.
policy_distribution = distributions.Normal(policy_mu, tf.sqrt(policy_sigma))
sample = tf.stop_gradient(policy_distribution.sample())
# Pick one training target, from easiest to hardest:
#   Learning a constant:
#     target = 3.14159
#   Learning the identity function:
#     target = features
#   Learning a non-linear function (requires non-linear activations):
#     target = tf.sign(features)
# Learning a time-dependent function (requires an RNN):
target = discount(features, 0.5, axis=1) - features

# Reward is the negative L1 distance between the sampled action and the
# target, summed over the observation axis; stop_gradient() keeps the
# reward out of the differentiation path.
distance = tf.reduce_sum(tf.abs(sample - target), axis=2)
reward = tf.stop_gradient(-distance)

# Compute the loss here using features, policy, sample, and reward.
# loss = ...
# TensorBoard summaries for monitoring training. `loss`, `policy_mu`, and
# `policy_sigma` must be defined by your network/loss code above.
summary_op = tf.summary.merge([
    tf.summary.scalar('Loss', loss),
    tf.summary.histogram('Mu', policy_mu),
    tf.summary.histogram('Sigma', policy_sigma),
    tf.summary.histogram('Sample', sample),
    tf.summary.histogram('Reward', reward)])

# You will need to adjust the learning rate.
optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-2)
train_op = optimizer.minimize(loss)

# TensorBoard summaries are written to the directory passed as the first
# command-line argument (needs `import sys` at the top of the file).
writer = tf.summary.FileWriter(sys.argv[1])
try:
    # Run the session as a context manager so its resources are released
    # even if training is interrupted (the original never closed it).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # range() instead of xrange() keeps the script Python 2/3 compatible.
        for step in range(100000):
            if step % 10 == 0:
                # Every 10th step, also evaluate and record the summaries.
                _, summaries = sess.run([train_op, summary_op])
                writer.add_summary(summaries, step)
            else:
                sess.run(train_op)
finally:
    # Flush and close the event file so no pending summaries are lost.
    writer.close()
# This function is written as plain NumPy code (as in the reference
# implementation) and lifted into the TensorFlow graph with tensorify
# (https://github.com/lemonzi/tensorify) so it can be applied to tensors
# above. Requires `import numpy as np` and `import scipy.signal`.
@tensorify.tensorflow_op(tf.float32, shape=lambda shapes: shapes[0])
def discount(x, gamma, axis=0):
    """Discounted cumulative sum of `x` along `axis`.

    Computes y[t] = x[t] + gamma * y[t+1], accumulating from the end of
    `axis` toward the beginning (the standard return computation in RL).

    Args:
        x: NumPy array of rewards/values.
        gamma: scalar discount factor.
        axis: axis along which to accumulate (default 0).

    Returns:
        float32 NumPy array with the same shape as `x`.
    """
    # BUG FIX: the original used `x[::-1]`, which reverses axis 0 only,
    # while lfilter filters along `axis` — so any call with axis != 0
    # (this script uses axis=1) silently computed the wrong result.
    # Build a slice tuple that reverses the *requested* axis instead.
    rev = [slice(None)] * x.ndim
    rev[axis] = slice(None, None, -1)
    rev = tuple(rev)
    y = scipy.signal.lfilter([1], [1, -gamma], x[rev], axis=axis)[rev]
    return y.astype(np.float32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment