@lsimmons2
Last active February 5, 2019 19:26
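
The method below implements the experience-replay step of a Double DQN agent (Keras-style network API): the online network selects the greedy next action, the target network evaluates it, and the online network is then fit toward the resulting targets.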
# method of a DQN agent class; random and numpy are module-level imports
import random

import numpy as np


def experience_replay(self):
    # Sample a random minibatch of transitions from the replay buffer.
    minibatch = random.sample(self.memory, EXPERIENCE_REPLAY_BATCH_SIZE)
    minibatch_new_q_values = []
    for experience in minibatch:
        state, action, reward, next_state, done = experience
        state = self._reshape_state_for_net(state)
        experience_new_q_values = self.online_network.predict(state)[0]
        if done:
            # Terminal transition: no bootstrapping, the target is the raw reward.
            q_update = reward
        else:
            next_state = self._reshape_state_for_net(next_state)
            # Double DQN: use the online network to SELECT the action ...
            online_net_selected_action = np.argmax(self.online_network.predict(next_state))
            # ... and the target network to EVALUATE that action's value.
            target_net_evaluated_q_value = (
                self.target_network.predict(next_state)[0][online_net_selected_action]
            )
            q_update = reward + GAMMA * target_net_evaluated_q_value
        # Overwrite only the taken action's Q-value; the other outputs keep the
        # network's own predictions, so they contribute no error to the loss.
        experience_new_q_values[action] = q_update
        minibatch_new_q_values.append(experience_new_q_values)
    # One gradient step over the whole minibatch.
    minibatch_states = np.array([e[0] for e in minibatch])
    minibatch_new_q_values = np.array(minibatch_new_q_values)
    self.online_network.fit(minibatch_states, minibatch_new_q_values, verbose=0, epochs=1)
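
The target computed above is the Double DQN target, reward + GAMMA * Q_target(next_state, argmax_a Q_online(next_state, a)): selecting the action with the online network while evaluating it with the target network reduces the overestimation bias of vanilla DQN. Because the loop calls predict() once per experience, it makes up to 3 * EXPERIENCE_REPLAY_BATCH_SIZE forward passes per update. Below is a minimal batched sketch of the same update; the attribute and constant names (self.memory, self.online_network, self.target_network, EXPERIENCE_REPLAY_BATCH_SIZE, GAMMA) come from the snippet above, while the method name and the assumption that stored states stack directly into a network input batch are mine, not part of the original gist.

def experience_replay_batched(self):
    # Same Double DQN targets as above, computed with three forward passes
    # per minibatch instead of up to three per experience.
    minibatch = random.sample(self.memory, EXPERIENCE_REPLAY_BATCH_SIZE)
    states = np.array([e[0] for e in minibatch])
    actions = np.array([e[1] for e in minibatch])
    rewards = np.array([e[2] for e in minibatch], dtype=np.float32)
    next_states = np.array([e[3] for e in minibatch])
    dones = np.array([e[4] for e in minibatch], dtype=np.float32)

    q_values = self.online_network.predict(states)  # shape (batch, n_actions)
    # Online network SELECTS the greedy next action ...
    selected_actions = np.argmax(self.online_network.predict(next_states), axis=1)
    # ... target network EVALUATES it; (1 - dones) zeroes the bootstrap
    # term for terminal transitions, matching the if/else above.
    next_q = self.target_network.predict(next_states)
    rows = np.arange(len(minibatch))
    targets = rewards + GAMMA * next_q[rows, selected_actions] * (1.0 - dones)

    q_values[rows, actions] = targets
    self.online_network.fit(states, q_values, verbose=0, epochs=1)

The batched form pays off as EXPERIENCE_REPLAY_BATCH_SIZE grows, since each predict() call carries fixed overhead regardless of batch size.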