Created April 24, 2017 16:20
Save lemonzi/0b1be3ba538a8ce3c0501114f9c4ec7e to your computer and use it in GitHub Desktop.
Unit testing the A3C algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys

import numpy as np
import scipy.signal
import tensorflow as tf
from tensorflow.contrib import distributions

import tensorify
# Unit-test harness for an A3C-style continuous policy: feed a known target,
# sample from the policy, and check that the policy loss can drive mu/sigma
# toward it. This file is a template -- the network and loss are left as
# placeholders and must be filled in before it will run.

# These are your input features, a random tensor.
# There are 128 examples per batch, 1000 time-steps, one observation.
# If your code doesn't use batches, remove the first dimension.
features = tf.random_normal([128, 1000, 1])

# Put your networks here (both are required below).
# policy_mu, policy_sigma = ...

# This is how you can sample actions from the policy in TensorFlow.
# The stop_gradient() is very important! Sampling must not be
# differentiated through; gradients reach the policy only via the loss.
# NOTE(review): tf.sqrt(policy_sigma) means sigma is treated as a
# variance here, not a standard deviation -- confirm that convention.
sample = tf.stop_gradient(
    distributions.Normal(policy_mu, tf.sqrt(policy_sigma)).sample())

# Choose ONE target, from easiest to hardest:
# Learning a constant.
# target = 3.14159
# Learning the identity function.
# target = features
# Learning a non-linear function. Requires non-linear activations.
# target = tf.sign(features)
# Learning a time-dependent function. Requires an RNN.
# NOTE(review): discount() is defined at the bottom of this file, so
# executing the file top-to-bottom raises NameError on this line -- move
# the discount definition above this point when assembling a runnable test.
target = discount(features, 0.5, axis=1) - features

# Reward: negative L1 distance between the sampled action and the target,
# summed over the observation axis. stop_gradient() keeps the reward
# itself out of backprop (it acts as a constant weight in the loss).
reward = tf.stop_gradient(-tf.reduce_sum(tf.abs(sample - target), axis=2))

# Compute the loss here using features, policy, sample, and reward.
# loss = ...

# TensorBoard diagnostics: watch Mu/Sigma collapse onto the target and
# the reward rise toward zero as training progresses.
summary_op = tf.summary.merge([
    tf.summary.scalar('Loss', loss),
    tf.summary.histogram('Mu', policy_mu),
    tf.summary.histogram('Sigma', policy_sigma),
    tf.summary.histogram('Sample', sample),
    tf.summary.histogram('Reward', reward)])

# You will need to adjust the learning rate.
optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-2)
train_op = optimizer.minimize(loss)

# TensorBoard summaries will be written to the directory given as argv[1].
writer = tf.summary.FileWriter(sys.argv[1])

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in xrange(100000):
    if i % 10 == 0:
        # Every 10th step, also evaluate and record the summaries.
        _, summaries = sess.run([train_op, summary_op])
        writer.add_summary(summaries, i)
    else:
        sess.run(train_op)

# The discount() helper used above is written in plain Python/NumPy and
# lifted into a TensorFlow op via tensorify:
# https://github.com/lemonzi/tensorify.
@tensorify.tensorflow_op(tf.float32, shape=lambda shapes: shapes[0])
def discount(x, gamma, axis=0):
    """Discounted cumulative sum of `x` along `axis`.

    Computes y[t] = x[t] + gamma * y[t+1] (looking forward in time), the
    standard discounted-return recurrence, implemented as an IIR filter
    over the time-reversed signal.

    Args:
        x: float array of rewards/values.
        gamma: discount factor in [0, 1).
        axis: time axis along which to discount.

    Returns:
        float32 array with the same shape as `x`.
    """
    # BUG FIX: the original used x[::-1], which always reverses axis 0;
    # lfilter then ran along `axis`, so any call with axis != 0 (this file
    # uses axis=1) flipped the wrong dimension. Flip along the same axis
    # the filter operates on, both before and after filtering.
    flipped = np.flip(x, axis=axis)
    y = scipy.signal.lfilter([1], [1, -gamma], flipped, axis=axis)
    return np.flip(y, axis=axis).astype(np.float32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment