@pythonlessons
Last active January 14, 2020 12:41
02_CartPole-reinforcement-learning_DDQN
def replay(self):
    if len(self.memory) < self.train_start:
        return
    # Randomly sample a minibatch from the replay memory
    minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))
    state = np.zeros((self.batch_size, self.state_size))
    next_state = np.zeros((self.batch_size, self.state_size))
    action, reward, done = [], [], []
    # unpack the minibatch into arrays before prediction
    # for speed this could be done at the tensor level,
    # but it is easier to understand as a loop
    for i in range(self.batch_size):
        state[i] = minibatch[i][0]
        action.append(minibatch[i][1])
        reward.append(minibatch[i][2])
        next_state[i] = minibatch[i][3]
        done.append(minibatch[i][4])
    # batch prediction is much faster than predicting one sample at a time
    target = self.model.predict(state)
    target_next = self.model.predict(next_state)
    target_val = self.target_model.predict(next_state)
    for i in range(len(minibatch)):
        # correct the Q value only for the action that was actually taken
        if done[i]:
            target[i][action[i]] = reward[i]
        else:
            if self.ddqn:  # Double DQN
                # the current (online) Q network selects the action
                # a'_max = argmax_a' Q(s', a')
                a = np.argmax(target_next[i])
                # the target Q network evaluates that action
                # Q_max = Q_target(s', a'_max)
                target[i][action[i]] = reward[i] + self.gamma * (target_val[i][a])
            else:  # Standard DQN
                # DQN takes the max Q value over the next actions;
                # selection and evaluation of the action both use the target Q network
                # Q_max = max_a' Q_target(s', a')
                target[i][action[i]] = reward[i] + self.gamma * (np.amax(target_val[i]))
    # Train the neural network on the sampled batch
    self.model.fit(state, target, batch_size=self.batch_size, verbose=0)
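
As a minimal standalone sketch of the difference between the two branches above (NumPy only; the reward, discount factor, and Q-values are dummy numbers chosen purely for illustration, and the variable names are not from the gist):

import numpy as np

gamma = 0.95                                 # discount factor (illustrative)
reward = 1.0                                 # reward of the sampled transition
q_online_next = np.array([0.2, 0.9, 0.5])    # Q(s', .) from the online model, like target_next[i]
q_target_next = np.array([0.3, 0.6, 0.8])    # Q_target(s', .) from the target model, like target_val[i]

# Standard DQN: the target network both selects and evaluates the next action
dqn_target = reward + gamma * np.amax(q_target_next)        # 1.0 + 0.95 * 0.8

# Double DQN: the online network selects the action, the target network evaluates it
a_max = np.argmax(q_online_next)                             # online network picks action 1
ddqn_target = reward + gamma * q_target_next[a_max]          # 1.0 + 0.95 * 0.6

print(dqn_target, ddqn_target)                               # approximately 1.76 and 1.57

Decoupling action selection (online network) from action evaluation (target network) is what reduces the overestimation bias of the max operator in standard DQN.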