Double Dueling Deep Q-Learning Network
from keras.callbacks import Callback as KerasCallback, CallbackList as KerasCallbackList
from keras.callbacks import TensorBoard
from keras.optimizers import Adam, RMSprop
import keras.backend as K
import numpy as np

class SubTensorBoard(TensorBoard):
    """Subclass of TensorBoard to log and visualize custom metrics and other values.

    Note that for this to work, you will have to define how `on_episode_end`
    calls are handled.
    Check https://github.com/martinholub/demos-blogs-examples/blob/master/rl-gym/atari/callbacks.py
    for a working implementation.
    """
    def __init__(self, *args, **kwargs):
        super(SubTensorBoard, self).__init__(*args, **kwargs)

    def lr_getter(self):
        # Reconstruct the current learning rate from the optimizer's state.
        decay = self.model.optimizer.decay
        lr = self.model.optimizer.lr
        iters = self.model.optimizer.iterations  # only this should not be constant
        if isinstance(self.model.optimizer, (Adam, )):
            # Get values
            beta_1 = self.model.optimizer.beta_1
            beta_2 = self.model.optimizer.beta_2
            # Calculate the bias-corrected step size, as Adam does internally
            lr = lr * (1. / (1. + decay * K.cast(iters, K.dtype(decay))))
            t = K.cast(iters, K.floatx()) + 1
            lr_t = lr * (K.sqrt(1. - K.pow(beta_2, t)) / (1. - K.pow(beta_1, t)))
            return np.float32(K.eval(lr_t))
        elif isinstance(self.model.optimizer, (RMSprop, )):
            lr = lr * (1. / (1. + decay * K.cast(iters, K.dtype(decay))))
            return np.float32(K.eval(lr))
        else:
            lr = lr * (1. / (1. + decay * K.cast(iters, K.dtype(decay))))
            return np.float32(K.eval(lr))

    def on_episode_end(self, episode, logs={}):
        # Log the current learning rate and reuse TensorBoard's epoch hook,
        # treating one episode as one epoch.
        logs.update({"lr": self.lr_getter()})
        super(SubTensorBoard, self).on_epoch_end(episode, logs)

class UpdateLossMask(KerasCallback):
    """Update a loss mask before each training step.

    An alternative is to define the loss as a new layer.
    """
    def __init__(self, error):
        # `error` is a function that takes a mask and returns a loss function
        # of (y_true, y_pred).
        self.error = error

    def on_batch_begin(self, step, logs={}):
        new_loss = self.error(self.model.input[1])
        if not isinstance(self.model.loss, list):
            new_loss = [new_loss]
        self.model.loss = new_loss
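A minimal usage sketch for `SubTensorBoard`, wired into a hand-rolled episode loop; `model`, `num_episodes`, and `total_reward` are placeholder names, not part of the gist:

tb = SubTensorBoard(log_dir='./logs')
tb.set_model(model)  # standard Keras callback wiring; `model` must be compiled

for episode in range(num_episodes):
    # ... run one episode, accumulating `total_reward` ...
    tb.on_episode_end(episode, logs={"episode_reward": total_reward})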
from keras import backend as K
import tensorflow as tf
import numpy as np

# Losses
def huber_loss(y_true, y_pred, clip_value=1.0):
    # Huber loss, see https://en.wikipedia.org/wiki/Huber_loss and
    # https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    # for details.
    assert clip_value > 0.
    x = y_true - y_pred
    if np.isinf(clip_value):
        # Special case for infinity, since TensorFlow has problems
        # with the comparison `K.abs(x) < np.inf`.
        return .5 * K.square(x)
    condition = K.abs(x) < clip_value
    squared_loss = .5 * K.square(x)
    linear_loss = clip_value * (K.abs(x) - .5 * clip_value)
    if hasattr(tf, 'select'):
        return tf.select(condition, squared_loss, linear_loss)  # condition, true, false
    else:
        return tf.where(condition, squared_loss, linear_loss)  # condition, true, false

def clipped_error(y_true, y_pred):
    return K.mean(huber_loss(y_true, y_pred), axis=-1)

def clipped_masked_error(mask):
    # Update this one with callbacks
    def clipped_error(y_true, y_pred):
        loss = huber_loss(y_true, y_pred)
        loss *= mask  # zero out the loss for actions that were not taken
        # loss = K.dot(loss, mask)
        return K.sum(loss, axis=-1)
    return clipped_error
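A quick sanity check of the two Huber branches (a standalone sketch; the expected values follow directly from the definition above):

# With clip_value = 1.0, |x| = 0.5 falls in the quadratic branch
# (0.5 * 0.5**2 = 0.125), while |x| = 3.0 falls in the linear branch
# (1.0 * (3.0 - 0.5) = 2.5).
y_true = K.constant([0.0, 0.0])
y_pred = K.constant([0.5, 3.0])
print(K.eval(huber_loss(y_true, y_pred)))  # approx. [0.125, 2.5]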
from keras import backend as K

# Metrics
def mean_q(y_true, y_pred):
    # Mean over the batch of the highest predicted Q-value per sample;
    # `y_true` is ignored but required by the Keras metric signature.
    return K.mean(K.max(y_pred, axis=1))
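A standalone sanity check of `mean_q` on a dummy batch (the values are made up for illustration):

# Max per row is [3.0, 6.0], so the metric evaluates to 4.5.
q_pred = K.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(K.eval(mean_q(None, q_pred)))  # 4.5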
def _construct_q_network(self):
    # Excerpt from the agent class; assumes imports from the surrounding file, e.g.:
    # from keras.layers import Input, Lambda, Convolution2D, Activation, Flatten, Dense, multiply
    # from keras.models import Model
    # from keras.regularizers import l2
    # from keras.optimizers import Adam, RMSprop
    # Mask that allows updating only the action that was observed
    mask_input = Input((self.action_size, ), name='mask')
    # Preprocess data on input; allows storing frames as uint8
    frames_input = Input(self.img_size + (self.num_frames, ), name='frames')
    # Scale by 142 instead of 255, because for Breakout the max pixel value is 142
    x = Lambda(lambda x: x / 142.0)(frames_input)
    x = Convolution2D(filters=32, kernel_size=(8, 8), strides=(4, 4),
                      # input_shape = self.img_size + (self.num_frames, ),
                      kernel_regularizer=l2(0.1),
                      kernel_initializer='he_normal')(x)
    x = Activation('relu')(x)
    x = Convolution2D(filters=64, kernel_size=(4, 4), strides=(2, 2),
                      kernel_regularizer=l2(0.1),
                      kernel_initializer='he_normal')(x)
    x = Activation('relu')(x)
    x = Convolution2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
                      kernel_regularizer=l2(0.01))(x)
    x = Activation('relu')(x)
    flatten = Flatten()(x)
    # Dueling DQN -- decompose the output into Advantage and Value parts:
    #   V(s): how good it is to be in any given state.
    #   A(a): how much better taking a certain action is compared to the others.
    fc1 = Dense(units=512, activation=None, kernel_regularizer=l2(0.1),
                kernel_initializer='he_normal')(flatten)
    advantage = Dense(self.action_size, activation=None,
                      kernel_regularizer=l2(0.1), kernel_initializer='he_normal')(fc1)
    fc2 = Dense(units=512, activation=None, kernel_regularizer=l2(0.01))(flatten)
    value = Dense(1, kernel_regularizer=l2(0.01))(fc2)
    # dueling_type == 'avg'
    # Q(s,a;theta) = V(s;theta) + (A(s,a;theta) - Avg_a(A(s,a;theta)))
    # Note: the mean must be taken over the action axis per sample
    # (axis=1, keepdims=True), not over the whole batch.
    policy = Lambda(lambda x: x[0] - K.mean(x[0], axis=1, keepdims=True) + x[1],
                    output_shape=(self.action_size, ))([advantage, value])
    filtered_policy = multiply([policy, mask_input])
    self.model = Model(inputs=[frames_input, mask_input], outputs=[filtered_policy])
    # Create an identical copy of the model; make sure they don't point to the same object
    config = self.model.get_config()
    self.target_model = Model.from_config(config)
    self.target_update()  # Ensure the weights are identical.
    losses = [clipped_masked_error(mask_input)]  # losses = ["MSE"]
    metrics = ["mae", mean_q]
    optimizer = RMSprop(lr=self.learn_rate,
                        epsilon=0.00,
                        rho=0.99,
                        decay=1e-6,
                        clipnorm=1.)
    self.model.compile(loss=losses,
                       optimizer=optimizer,
                       metrics=metrics)
    # Loss, optimizer and metrics are just dummies, as the target model is never trained
    self.target_model.compile(loss='MSE',
                              optimizer=Adam(),
                              metrics=[])
    print(self.model.summary())
    print("Successfully constructed networks.")
def replay(self, discount, logger):
    # Excerpt from the agent class; assumes `import random`, `import numpy as np`,
    # and the `Memory` class from the surrounding file.
    if isinstance(self.memory, (Memory, )):
        ## Prioritized-memory style
        minibatch, idxs, is_weights = self.memory.sample(self.replay_size)
        states, q_targets, errors, mask = self.predict_batch(minibatch, discount)
        self.memory.batch_update(idxs, errors)
        # Mock validation to enable TensorBoard callback visualizations
        history = self.model.fit([states, mask], q_targets,
                                 batch_size=len(minibatch),
                                 validation_split=0.2, verbose=0)
    else:
        ## Simple style: draw a random minibatch sample from memory
        minibatch = random.sample(self.memory, min(len(self.memory), self.replay_size))
        states, q_targets, errors, mask = self.predict_batch(minibatch, discount)
        # `states` is x, `q_targets` is y
        # Mock validation to enable TensorBoard callback visualizations
        history = self.model.fit([states, mask], q_targets, batch_size=len(minibatch),
                                 validation_split=0.2, verbose=0)
    metrics = []
    for met in self.model.metrics_names:
        metrics.extend(history.history[met])
    return np.array(metrics), history.validation_data

def predict_batch(self, minibatch, discount):
    """Predict on a batch and return the target that the model will try to fit.

    Briefly, we try to make the network (`model`) predict its own output. This can
    become unstable, as the sought value is also being changed while we try to
    approach it. To mitigate this, we keep a target network (`target_model`)
    that we update only occasionally.
    Additionally, we also predict on old states so that we can compute an
    `error` that gives us a way to increase the importance of samples that
    our model is worst at predicting (largest error).
    https://arxiv.org/pdf/1509.06461.pdf
    https://medium.com/@awjuliani/simple-reinforcement-learning-with-tensorflow-part-4-deep-q-networks-and-beyond-8438a3e2b8df
    """
    # Must work also on batch size 1
    batch_size = len(minibatch)
    # Split the batch into its parts (s, a, r, s', d)
    # minibatch = np.reshape(minibatch, (batch_size, 5))
    # states, actions, rewards, next_states, dones = np.split(minibatch, 5, axis=1)
    states = []
    rewards = []
    actions = []
    next_states = []
    dones = []
    for _, (state, action, reward, next_state, done) in enumerate(minibatch):
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
    ## DOUBLE DQN
    # Use the primary network to choose an action
    all_one_mask = np.ones((len(actions), ) + (self.action_size, ),
                           dtype=np.uint8)
    q_nexts = self.model.predict_on_batch([np.reshape(next_states,
                                                      (-1, *self.img_size, self.num_frames)),
                                           all_one_mask])
    next_actions = np.argmax(q_nexts, axis=1)
    # Use the target network to generate the Q-value for that action
    q_targets = self.target_model.predict_on_batch([np.reshape(next_states,
                                                               (-1, *self.img_size, self.num_frames)),
                                                    all_one_mask])
    # Predict the future discounted reward; target == reward if done
    targets = rewards + \
              discount * np.invert(dones).astype(np.float32) * \
              q_targets[range(batch_size), next_actions]
    # Update only the actions for which we have an observation.
    # This is simultaneously implemented on the model level (via the mask),
    # so it should not be an issue; in the future, remove it from here.
    q_targets[range(batch_size), actions] = targets  # update q to future
    # Need this one for the error term for the memory update
    q_olds = self.model.predict_on_batch([np.reshape(states,
                                                     (-1, *self.img_size, self.num_frames)),
                                          all_one_mask])
    targets_old = q_olds[range(batch_size), actions]  # pull out the old value of q_hat
    # Get the error for updating priorities in the memory
    errors = abs(targets_old - targets)
    # Get a mask to update only the Q-value that was observed
    mask = np.zeros((batch_size, self.action_size), dtype=np.uint8)
    mask[range(batch_size), actions] = 1
    # Reshape as necessary
    states = np.stack(states, axis=0)  # creates a new axis
    # q_targets = np.concatenate(q_targets, axis=0)
    return (states, q_targets, errors, mask)
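The same all-ones-mask trick used in `predict_batch` also applies at action-selection time; a hypothetical epsilon-greedy `act` method (not part of the gist) might look like this:

def act(self, state, epsilon):
    # Explore with probability epsilon, otherwise exploit the network.
    if np.random.rand() < epsilon:
        return np.random.randint(self.action_size)
    # Pass an all-ones mask so no Q-value is filtered out at inference.
    all_one_mask = np.ones((1, self.action_size), dtype=np.uint8)
    state = np.reshape(state, (1, *self.img_size, self.num_frames))
    q_values = self.model.predict_on_batch([state, all_one_mask])
    return int(np.argmax(q_values[0]))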