Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created May 2, 2023 10:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save horoiwa/a2582747eb6450a66fb0bdfed0bf337d to your computer and use it in GitHub Desktop.
Save horoiwa/a2582747eb6450a66fb0bdfed0bf337d to your computer and use it in GitHub Desktop.
def update_policy(self, states, actions):
    """Advantage-weighted regression (AWR) update for the policy network.

    Forms per-sample weights exp((Q - V) * temperature), clipped at 100.0
    to keep them bounded, then takes one optimizer step minimizing the
    negative weighted log-likelihood of the observed actions.

    Args:
        states: Batch of states accepted by the value, Q, and policy nets.
        actions: Batch of actions taken in those states.

    Returns:
        The scalar policy loss for this batch.
    """
    # The advantage weights are built OUTSIDE the gradient tape, so no
    # gradient flows from the policy loss back into the critics.
    q_one, q_two = self.target_qnet(states, actions)
    q_min = tf.minimum(q_one, q_two)
    value = self.valuenet(states)
    # NOTE(review): temperature MULTIPLIES the advantage here (some AWR
    # variants divide by it) — this matches the original code as written.
    advantage_weight = tf.minimum(
        tf.exp((q_min - value) * self.temperature), 100.0)

    with tf.GradientTape() as tape:
        action_dists = self.policy(states)
        logp = tf.reshape(action_dists.log_prob(actions), (-1, 1))
        # Maximizing the weighted log-likelihood == minimizing its negation.
        loss = tf.reduce_mean(-1 * (advantage_weight * logp))

    policy_vars = self.policy.trainable_variables
    gradients = tape.gradient(loss, policy_vars)
    self.p_optimizer.apply_gradients(zip(gradients, policy_vars))
    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment