@NMZivkovic
Created July 20, 2019 10:36
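import random

import numpy as np


def retrain(self, batch_size):
    # Sample a random minibatch of past transitions from the replay buffer.
    minibatch = random.sample(self.expirience_replay, batch_size)
    for state, action, reward, next_state, terminated in minibatch:
        # Add a leading batch dimension, since predict() and fit() take batched input.
        state = np.expand_dims(np.asarray(state).astype(np.float64), axis=0)
        next_state = np.expand_dims(np.asarray(next_state).astype(np.float64), axis=0)
        # Start from the current Q-value estimates; only the taken action is overwritten.
        target = self.q_network.predict(state)
        if terminated:
            # Terminal transition: there is no future reward to bootstrap from.
            target[0][action] = reward
        else:
            # Bootstrap from the target network's best Q-value for the next state.
            t = self.target_network.predict(next_state)
            target[0][action] = reward + self.gamma * np.amax(t)
        # Fit the online network on this single sample for one epoch.
        self.q_network.fit(state, target, epochs=1, verbose=0)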
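The target update in `retrain` can also be written for a whole minibatch at once. The following is a minimal NumPy-only sketch of that arithmetic, not part of the original gist: `td_targets` and its parameter names are hypothetical, and it assumes the Q-values for the states and next states have already been predicted as `(batch, n_actions)` arrays.

```python
import numpy as np


def td_targets(q_values, next_q_values, actions, rewards, terminated, gamma):
    """Hypothetical helper: batched version of the per-sample target update.

    q_values:      (batch, n_actions) online-network predictions for the states
    next_q_values: (batch, n_actions) target-network predictions for the next states
    """
    # Copy so the caller's predictions are not modified in place.
    targets = np.array(q_values, dtype=np.float64, copy=True)
    next_q_values = np.asarray(next_q_values, dtype=np.float64)
    actions = np.asarray(actions, dtype=np.intp)
    rewards = np.asarray(rewards, dtype=np.float64)
    done = np.asarray(terminated, dtype=bool)

    # Discounted best next-state value; dropped for terminal transitions.
    bootstrap = gamma * np.max(next_q_values, axis=1)
    rows = np.arange(targets.shape[0])
    # Overwrite only the entry of the action that was actually taken.
    targets[rows, actions] = rewards + np.where(done, 0.0, bootstrap)
    return targets


example = td_targets(
    q_values=[[1.0, 2.0], [3.0, 4.0]],
    next_q_values=[[10.0, 20.0], [30.0, 40.0]],
    actions=[0, 1],
    rewards=[1.0, 5.0],
    terminated=[False, True],
    gamma=0.5,
)
print(example)
```

Under these assumptions, the resulting array could be passed to a single `fit` call on the stacked states instead of fitting one sample at a time. Whether that trains as well as the per-sample loop in this setup is not something the gist establishes.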