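# --- Hypothetical setup sketch (not part of the original gist) ---
# The training step below references mu, q_mu, their target copies, two
# optimizers, a replay buffer, and a few hyperparameters without defining
# them. Everything in this block is an assumption about that surrounding
# code, shown only so the snippet runs end to end; the network sizes,
# learning rates, and ReplayBuffer class are illustrative, not the
# author's actual choices.
import random
from collections import deque

import numpy as np
import tensorflow as tf

obs_dim, act_dim = 3, 1          # assumed environment dimensions
action_max = 2.0                 # assumed action-space bound
gamma, batch_size = 0.99, 32     # assumed hyperparameters

def make_actor():
    # State -> action in [-1, 1]; scaled by action_max where it is used.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu", input_shape=(obs_dim,)),
        tf.keras.layers.Dense(act_dim, activation="tanh"),
    ])

def make_critic():
    # Concatenated (state, action) -> scalar Q-value.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu",
                              input_shape=(obs_dim + act_dim,)),
        tf.keras.layers.Dense(1),
    ])

mu, mu_target = make_actor(), make_actor()
q_mu, q_mu_target = make_critic(), make_critic()
mu_target.set_weights(mu.get_weights())      # targets start as exact copies
q_mu_target.set_weights(q_mu.get_weights())

mu_optimizer = tf.keras.optimizers.Adam(1e-3)
q_optimizer = tf.keras.optimizers.Adam(1e-3)
mu_losses, q_losses = [], []

class ReplayBuffer:
    # Minimal FIFO buffer; sample() returns five parallel lists:
    # states, actions, rewards, next states, done flags.
    def __init__(self, size=100_000):
        self.buf = deque(maxlen=size)

    def store(self, s, a, r, s2, d):
        self.buf.append((s, a, r, s2, d))

    def sample(self, n):
        batch = random.sample(self.buf, n)
        return map(list, zip(*batch))

replay_buffer = ReplayBuffer()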
# One DDPG training step: sample a batch, update the actor (mu), then the critic (q_mu).
X, A, R, X2, D = replay_buffer.sample(batch_size)
X = np.asarray(X, dtype=np.float32)
A = np.asarray(A, dtype=np.float32)
R = np.asarray(R, dtype=np.float32)
X2 = np.asarray(X2, dtype=np.float32)
D = np.asarray(D, dtype=np.float32)
Xten = tf.convert_to_tensor(X)

# Actor optimization: ascend Q(s, mu(s)) by minimizing its negative mean.
with tf.GradientTape() as tape2:
    # Call the models directly rather than via predict_on_batch():
    # predict_on_batch() returns NumPy arrays, which the tape cannot
    # differentiate through, so the actor would receive no gradients.
    Aprime = action_max * mu(Xten)
    temp = tf.concat([Xten, Aprime], axis=1)
    Q = q_mu(temp)
    mu_loss = -tf.reduce_mean(Q)
grads_mu = tape2.gradient(mu_loss, mu.trainable_variables)
mu_losses.append(mu_loss)
mu_optimizer.apply_gradients(zip(grads_mu, mu.trainable_variables))

# Critic optimization: regress Q(s, a) toward the Bellman target
# r + gamma * (1 - d) * Q_target(s', mu_target(s')).
# The target uses the frozen target networks and needs no gradients,
# so it is computed outside the tape.
next_a = action_max * mu_target.predict_on_batch(X2)
temp = np.concatenate((X2, next_a), axis=1)
# Squeeze the (batch, 1) critic output to (batch,) so it broadcasts
# correctly against the 1-D reward and done arrays.
q_target = R + gamma * (1 - D) * np.squeeze(q_mu_target.predict_on_batch(temp))
with tf.GradientTape() as tape:
    temp2 = tf.concat([Xten, tf.convert_to_tensor(A)], axis=1)
    qvals = tf.squeeze(q_mu(temp2), axis=1)
    q_loss = tf.reduce_mean((qvals - q_target) ** 2)
grads_q = tape.gradient(q_loss, q_mu.trainable_variables)
q_optimizer.apply_gradients(zip(grads_q, q_mu.trainable_variables))
q_losses.append(q_loss)
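
# --- Target-network update (hedged sketch) ---
# DDPG also refreshes mu_target and q_mu_target each step via Polyak
# averaging. That update is not in this excerpt, so the code below is a
# standard-practice sketch; the decay factor `tau` and the helper
# `soft_update` are assumed names, not taken from the gist.
tau = 0.995  # fraction of the old target weights to keep

def soft_update(target_net, main_net, tau):
    # new_target = tau * old_target + (1 - tau) * main
    target_net.set_weights([
        tau * wt + (1.0 - tau) * w
        for wt, w in zip(target_net.get_weights(), main_net.get_weights())
    ])

soft_update(mu_target, mu, tau)
soft_update(q_mu_target, q_mu, tau)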