02_CartPole-reinforcement-learning_DDQN
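The replay() method below is an excerpt from a larger agent class, so it relies on attributes (memory, batch_size, gamma, model, target_model, ddqn, ...) and imports (random, numpy, Keras) that are not shown in this gist. As a point of reference, here is a minimal sketch of what such an agent class might look like; the class name DQNAgent, the network architecture, and all hyperparameter values are assumptions for illustration, not part of the original file.

import random
from collections import deque

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import RMSprop

class DQNAgent:
    def __init__(self, state_size, action_size, ddqn=True):
        self.state_size = state_size      # size of the observation vector (4 for CartPole)
        self.action_size = action_size    # number of discrete actions (2 for CartPole)
        self.memory = deque(maxlen=2000)  # replay buffer of (state, action, reward, next_state, done) tuples
        self.gamma = 0.95                 # discount factor
        self.batch_size = 64              # minibatch size used in replay()
        self.train_start = 1000           # start training once the buffer holds this many samples
        self.ddqn = ddqn                  # True -> Double DQN target, False -> standard DQN target

        # online network (trained on every replay() call) and target network (synced periodically)
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.target_model.set_weights(self.model.get_weights())

    def _build_model(self):
        model = Sequential([
            Dense(64, input_dim=self.state_size, activation="relu"),
            Dense(64, activation="relu"),
            Dense(self.action_size, activation="linear"),
        ])
        model.compile(loss="mse", optimizer=RMSprop(learning_rate=0.00025))
        return model

The replay() method from the gist would then be defined as a method of this class.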
def replay(self):
    if len(self.memory) < self.train_start:
        return
    # Randomly sample a minibatch from the replay memory
    minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))

    state = np.zeros((self.batch_size, self.state_size))
    next_state = np.zeros((self.batch_size, self.state_size))
    action, reward, done = [], [], []

    # Unpack the minibatch before prediction.
    # For speed this could be done at the tensor level,
    # but a loop is easier to follow.
    for i in range(self.batch_size):
        state[i] = minibatch[i][0]
        action.append(minibatch[i][1])
        reward.append(minibatch[i][2])
        next_state[i] = minibatch[i][3]
        done.append(minibatch[i][4])

    # Batch predictions are much faster than per-sample calls
    target = self.model.predict(state)
    target_next = self.model.predict(next_state)
    target_val = self.target_model.predict(next_state)

    for i in range(len(minibatch)):
        # correct the Q value only for the action that was actually taken
        if done[i]:
            target[i][action[i]] = reward[i]
        else:
            if self.ddqn:  # Double DQN
                # the online network selects the action:
                # a'_max = argmax_a' Q(s', a')
                a = np.argmax(target_next[i])
                # the target network evaluates that action:
                # Q_max = Q_target(s', a'_max)
                target[i][action[i]] = reward[i] + self.gamma * target_val[i][a]
            else:  # standard DQN
                # both selection and evaluation of the next action
                # use the target network: Q_max = max_a' Q_target(s', a')
                target[i][action[i]] = reward[i] + self.gamma * np.amax(target_val[i])

    # Train the online network on the whole batch
    self.model.fit(state, target, batch_size=self.batch_size, verbose=0)
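
To show where replay() fits, here is a hedged sketch of a training loop that could drive it. It assumes the DQNAgent sketch above, the classic Gym API in which env.step() returns a 4-tuple, and a simple epsilon-greedy policy with an arbitrary decay schedule; none of this is part of the original gist.

import gym

env = gym.make("CartPole-v1")
agent = DQNAgent(state_size=env.observation_space.shape[0],
                 action_size=env.action_space.n, ddqn=True)
epsilon = 1.0  # exploration rate, decayed toward 0.01

for episode in range(300):
    state = np.reshape(env.reset(), (1, agent.state_size))
    done = False
    while not done:
        # epsilon-greedy action selection on the online network
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = np.argmax(agent.model.predict(state)[0])
        next_state, reward, done, _ = env.step(action)
        next_state = np.reshape(next_state, (1, agent.state_size))
        # store the transition and train on a random minibatch
        agent.memory.append((state, action, reward, next_state, done))
        agent.replay()  # the method shown above
        state = next_state
        epsilon = max(0.01, epsilon * 0.999)
    # sync the target network with the online network once per episode
    agent.target_model.set_weights(agent.model.get_weights())

Syncing the target network once per episode is one simple choice; the original tutorial may update it on a different schedule.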