@d0znpp
Created December 12, 2017 00:44
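The snippet below builds the TensorFlow 1.x graph of a REINFORCE-style policy-gradient controller: a placeholder for the raw state, a forward pass through a policy network that yields action scores and a scaled integer action, a softmax cross-entropy loss plus an L2 regularization term, and gradients rescaled by the discounted rewards before being applied. It is a class method, so tf (TensorFlow 1.x) and the attributes self.policy_network, self.max_layers, self.division_rate, self.reg_param, and self.optimizer are assumed to be defined elsewhere on the same object.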
def create_variables(self):
    with tf.name_scope("model_inputs"):
        # raw state representation
        self.states = tf.placeholder(tf.float32, [None, self.max_layers*4], name="states")

    with tf.name_scope("predict_actions"):
        # initialize policy network
        with tf.variable_scope("policy_network"):
            self.policy_outputs = self.policy_network(self.states, self.max_layers)

        self.action_scores = tf.identity(self.policy_outputs, name="action_scores")
        self.predicted_action = tf.cast(tf.scalar_mul(self.division_rate, self.action_scores), tf.int32, name="predicted_action")

    # variables used for the regularization loss
    policy_network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy_network")

    # compute loss and gradients
    with tf.name_scope("compute_gradients"):
        # rewards used to scale the gradients of the selected actions
        self.discounted_rewards = tf.placeholder(tf.float32, (None,), name="discounted_rewards")

        # reuse the policy network weights for the training pass
        with tf.variable_scope("policy_network", reuse=True):
            self.logprobs = self.policy_network(self.states, self.max_layers)

        # compute policy loss and regularization loss
        self.cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logprobs, labels=self.states)
        self.pg_loss = tf.reduce_mean(self.cross_entropy_loss)
        self.reg_loss = tf.reduce_sum([tf.reduce_sum(tf.square(x)) for x in policy_network_variables])
        self.loss = self.pg_loss + self.reg_param * self.reg_loss

        # compute gradients
        self.gradients = self.optimizer.compute_gradients(self.loss)

        # scale each policy gradient by the discounted rewards
        for i, (grad, var) in enumerate(self.gradients):
            if grad is not None:
                self.gradients[i] = (grad * self.discounted_rewards, var)

    # training update
    with tf.name_scope("train_policy_network"):
        # apply gradients to update policy network
        self.train_op = self.optimizer.apply_gradients(self.gradients)
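For context, a minimal sketch of how this graph might be driven once it has been built, assuming a controller object whose constructor already called create_variables() and a TensorFlow 1.x session. The names controller, sample_action, and train_step are illustrative assumptions and do not appear in the gist; only the attributes they touch (states, discounted_rewards, predicted_action, train_op, loss) come from the code above.

import tensorflow as tf  # TensorFlow 1.x API

def sample_action(sess, controller, state_batch):
    # state_batch: shape [batch, max_layers * 4], matching the "states" placeholder
    return sess.run(controller.predicted_action,
                    feed_dict={controller.states: state_batch})

def train_step(sess, controller, state_batch, reward_batch):
    # reward_batch feeds the "discounted_rewards" placeholder; because create_variables
    # multiplies every gradient tensor by this placeholder via broadcasting, a
    # single-element reward batch is the safest shape to feed here.
    _, loss = sess.run([controller.train_op, controller.loss],
                       feed_dict={controller.states: state_batch,
                                  controller.discounted_rewards: reward_batch})
    return loss

# typical TF1 setup before calling the helpers above
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     action = sample_action(sess, controller, state_batch)
#     loss = train_step(sess, controller, state_batch, reward_batch)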