@horoiwa
Created May 10, 2020 07:14
DQN loss

The agent builds the TD target r + γ · max_a' Q_target(s', a') with a separate target network (falling back to just r when the episode is done), and the Q-network is trained to minimize the mean squared error between that target and the Q-value of the action that was actually taken.
import numpy as np
import tensorflow as tf


class DQNAgent:
    """Omitted: other methods"""

    def update_qnetwork(self):
        # Sample a minibatch of transitions from the replay buffer.
        (states, actions, rewards,
         next_states, dones) = self.get_minibatch(self.BATCH_SIZE)

        # Max Q-value of each next state, estimated by the target network.
        next_Qs = np.max(self.target_network.predict(next_states), axis=1)

        # TD target: r + gamma * max_a' Q_target(s', a'), or just r at episode end.
        target_values = [reward + self.gamma * next_q if not done else reward
                         for reward, next_q, done
                         in zip(rewards, next_Qs, dones)]

        self.q_network.update(np.array(states), np.array(actions),
                              np.array(target_values))


class QNetwork(tf.keras.Model):
    """Omitted: other methods"""

    def update(self, states, selected_actions, target_values):
        with tf.GradientTape() as tape:
            # Q-values of the actions actually taken, picked out via one-hot masks.
            selected_actions_onehot = tf.one_hot(selected_actions,
                                                 self.action_space)
            selected_action_values = tf.reduce_sum(
                self(states) * selected_actions_onehot, axis=1)
            # Mean squared error between the TD targets and the current estimates.
            loss = tf.reduce_mean(
                tf.square(target_values - selected_action_values))

        # Gradient descent step on the Q-network parameters.
        variables = self.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
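
Below is a minimal sketch of how these two classes might be driven from a training loop. The environment API, get_action, store_transition, and the interval constants are assumptions for illustration; only update_qnetwork and the online/target networks come from the gist. Periodically copying the online weights into the target network is the standard DQN recipe.

# Hypothetical training-loop sketch: env, get_action, store_transition and the
# TOTAL_STEPS / *_INTERVAL constants are assumptions, not part of this gist.
agent = DQNAgent()
state = env.reset()
for step in range(TOTAL_STEPS):
    action = agent.get_action(state)          # e.g. epsilon-greedy over q_network
    next_state, reward, done, _ = env.step(action)
    agent.store_transition(state, action, reward, next_state, done)
    state = env.reset() if done else next_state

    if step % UPDATE_INTERVAL == 0:
        agent.update_qnetwork()               # the update defined above
    if step % SYNC_INTERVAL == 0:
        # Copy online weights into the target network.
        agent.target_network.set_weights(agent.q_network.get_weights())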