Skip to content

Instantly share code, notes, and snippets.

@awjuliani
Created December 16, 2016 22:46
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save awjuliani/66e8f477fc1ad000b1314809d8523455 to your computer and use it in GitHub Desktop.
Save awjuliani/66e8f477fc1ad000b1314809d8523455 to your computer and use it in GitHub Desktop.
class AC_Network():
    """Actor-critic network: conv encoder -> LSTM -> policy and value heads.

    The graph is built under ``tf.variable_scope(scope)``. Each row of
    ``self.inputs`` is one flattened frame. All rows fed in a single run are
    treated as consecutive time steps of ONE sequence; the LSTM batch size is
    fixed at 1.

    Attributes set on the instance:
        inputs: float32 placeholder of shape [None, s_size].
        imageIn: ``inputs`` reshaped to [-1, *image_shape].
        conv1, conv2: ELU convolution layers of the visual encoder.
        state_init: [c, h] numpy float32 zero arrays of shape (1, lstm_units),
            usable to reset the recurrent state.
        state_in: (c_in, h_in) float32 placeholders of shape [1, lstm_units]
            holding the recurrent state fed into each run.
        state_out: (c, h) LSTM state after the sequence, shape [1, lstm_units].
        policy: softmax action probabilities, shape [steps, a_size].
        value: linear value estimate, shape [steps, 1].
    """

    def __init__(self, s_size, a_size, scope, trainer,
                 image_shape=(84, 84, 1), lstm_units=256):
        """Build the network graph.

        Args:
            s_size: length of one flattened input frame; expected to equal the
                product of ``image_shape`` (84*84*1 = 7056 by default).
            a_size: number of discrete actions (width of the policy head).
            scope: name of the variable scope the graph is built in.
            trainer: not read by this constructor; kept for interface
                compatibility. NOTE(review): presumably an optimizer used by
                training code elsewhere — confirm.
            image_shape: (height, width, channels) that each input row is
                reshaped to. Defaults to the previously hard-coded (84, 84, 1).
            lstm_units: number of LSTM units. Defaults to the previously
                hard-coded 256.
        """
        with tf.variable_scope(scope):
            # NOTE: layers below are created without explicit names, so their
            # variable names depend on creation order. Keep the order stable
            # so checkpoints / variable-list syncing keep matching.

            # Input and visual encoding layers.
            self.inputs = tf.placeholder(shape=[None, s_size], dtype=tf.float32)
            self.imageIn = tf.reshape(self.inputs,
                                      shape=[-1] + list(image_shape))
            self.conv1 = slim.conv2d(activation_fn=tf.nn.elu,
                                     inputs=self.imageIn, num_outputs=16,
                                     kernel_size=[8, 8], stride=[4, 4],
                                     padding='VALID')
            self.conv2 = slim.conv2d(activation_fn=tf.nn.elu,
                                     inputs=self.conv1, num_outputs=32,
                                     kernel_size=[4, 4], stride=[2, 2],
                                     padding='VALID')
            hidden = slim.fully_connected(slim.flatten(self.conv2), 256,
                                          activation_fn=tf.nn.elu)

            # Recurrent network for temporal dependencies.
            lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_units,
                                                     state_is_tuple=True)
            c_init = np.zeros((1, lstm_cell.state_size.c), np.float32)
            h_init = np.zeros((1, lstm_cell.state_size.h), np.float32)
            self.state_init = [c_init, h_init]
            c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c])
            h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h])
            self.state_in = (c_in, h_in)
            # Treat the whole input batch as a single sequence:
            # [steps, features] -> [1, steps, features].
            rnn_in = tf.expand_dims(hidden, [0])
            # Sequence length = number of frames fed in (imageIn's batch dim).
            step_size = tf.shape(self.imageIn)[:1]
            state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
            lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
                lstm_cell, rnn_in, initial_state=state_in,
                sequence_length=step_size, time_major=False)
            lstm_c, lstm_h = lstm_state
            self.state_out = (lstm_c[:1, :], lstm_h[:1, :])
            # Back to one row per step. The width is taken from the cell itself
            # so it can never drift from the LSTM size used above.
            rnn_out = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

            # Output layers for policy and value estimations.
            # normalized_columns_initializer is defined elsewhere in the file;
            # NOTE(review): the small 0.01 scale on the policy head presumably
            # keeps the initial policy close to uniform — confirm.
            self.policy = slim.fully_connected(
                rnn_out, a_size,
                activation_fn=tf.nn.softmax,
                weights_initializer=normalized_columns_initializer(0.01),
                biases_initializer=None)
            self.value = slim.fully_connected(
                rnn_out, 1,
                activation_fn=None,
                weights_initializer=normalized_columns_initializer(1.0),
                biases_initializer=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment