@awjuliani
Created December 16, 2016 22:56
class AC_Network():
    def __init__(self, s_size, a_size, scope, trainer):
        ....
        ....
        ....
        if scope != 'global':
            self.actions = tf.placeholder(shape=[None], dtype=tf.int32)
            self.actions_onehot = tf.one_hot(self.actions, a_size, dtype=tf.float32)
            self.target_v = tf.placeholder(shape=[None], dtype=tf.float32)
            self.advantages = tf.placeholder(shape=[None], dtype=tf.float32)

            self.responsible_outputs = tf.reduce_sum(self.policy * self.actions_onehot, [1])

            # Loss functions
            self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value, [-1])))
            self.entropy = -tf.reduce_sum(self.policy * tf.log(self.policy))
            self.policy_loss = -tf.reduce_sum(tf.log(self.responsible_outputs) * self.advantages)
            self.loss = 0.5 * self.value_loss + self.policy_loss - self.entropy * 0.01

            # Get gradients from local network using local losses
            local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
            self.gradients = tf.gradients(self.loss, local_vars)
            self.var_norms = tf.global_norm(local_vars)
            grads, self.grad_norms = tf.clip_by_global_norm(self.gradients, 40.0)

            # Apply local gradients to global network
            global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
            self.apply_grads = trainer.apply_gradients(zip(grads, global_vars))
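
For reference, here is a minimal NumPy sketch of the same three loss terms on a toy batch. The numbers are made up purely for illustration (they are not from the gist); policy stands for the softmax output of the local network and value for the critic's estimate.

import numpy as np

# Toy batch: 2 timesteps, 3 actions (illustrative values only)
policy = np.array([[0.2, 0.5, 0.3],
                   [0.1, 0.1, 0.8]])      # pi(a | s) per state
value = np.array([1.0, 0.5])              # V(s) estimates
actions = np.array([1, 2])
target_v = np.array([1.2, 0.4])           # discounted returns
advantages = np.array([0.3, -0.1])        # GAE advantages

actions_onehot = np.eye(3)[actions]
responsible_outputs = np.sum(policy * actions_onehot, axis=1)   # pi(a_t | s_t)

value_loss = 0.5 * np.sum((target_v - value) ** 2)
entropy = -np.sum(policy * np.log(policy))                       # encourages exploration
policy_loss = -np.sum(np.log(responsible_outputs) * advantages)
loss = 0.5 * value_loss + policy_loss - 0.01 * entropy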
class Worker():
    ....
    ....
    ....
    def train(self, global_AC, rollout, sess, gamma, bootstrap_value):
        rollout = np.array(rollout)
        observations = rollout[:, 0]
        actions = rollout[:, 1]
        rewards = rollout[:, 2]
        next_observations = rollout[:, 3]
        values = rollout[:, 5]

        # Here we take the rewards and values from the rollout, and use them to
        # generate the advantage and discounted returns.
        # The advantage function uses "Generalized Advantage Estimation".
        self.rewards_plus = np.asarray(rewards.tolist() + [bootstrap_value])
        discounted_rewards = discount(self.rewards_plus, gamma)[:-1]
        self.value_plus = np.asarray(values.tolist() + [bootstrap_value])
        advantages = rewards + gamma * self.value_plus[1:] - self.value_plus[:-1]
        advantages = discount(advantages, gamma)

        # Update the global network using gradients from loss.
        # Generate network statistics to periodically save.
        rnn_state = self.local_AC.state_init
        feed_dict = {self.local_AC.target_v: discounted_rewards,
                     self.local_AC.inputs: np.vstack(observations),
                     self.local_AC.actions: actions,
                     self.local_AC.advantages: advantages,
                     self.local_AC.state_in[0]: rnn_state[0],
                     self.local_AC.state_in[1]: rnn_state[1]}
        v_l, p_l, e_l, g_n, v_n, _ = sess.run([self.local_AC.value_loss,
                                               self.local_AC.policy_loss,
                                               self.local_AC.entropy,
                                               self.local_AC.grad_norms,
                                               self.local_AC.var_norms,
                                               self.local_AC.apply_grads],
                                              feed_dict=feed_dict)
        return v_l / len(rollout), p_l / len(rollout), e_l / len(rollout), g_n, v_n
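
The discount helper used above isn't defined in this snippet. A common way to implement it (and, I believe, the one used in the accompanying tutorial code; treat this as an assumption) is a reverse discounted cumulative sum via scipy.signal.lfilter:

import numpy as np
import scipy.signal

def discount(x, gamma):
    # Discounted cumulative sum: y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ...
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

# e.g. discount(np.array([1.0, 1.0, 1.0]), 0.99) -> [2.9701, 1.99, 1.0]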
@trenkvaz
class Worker():
    ....
    ....
    ....
    def train(self, global_AC, rollout, sess, gamma, bootstrap_value):

Why does global_AC need to be passed to this function?
It isn't actually used inside it, is it?
