import signal
import sys
from keras.callbacks import TensorBoard, CallbackList
from keras.engine.saving import load_model
from keras.initializers import Zeros, VarianceScaling
from keras.optimizers import Adam
from keras.layers import Input, Permute, Convolution2D, Activation, Flatten, Dense
from keras import Model
from collections import deque
import numpy as np
import os
import pickle
import tensorflow as tf

from learningAgents import ReinforcementAgent
from game import Actions
class TensorBoardWrap(TensorBoard):
    def __init__(self, generator, **args):
        TensorBoard.__init__(self, **args)
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        sample_weights = [np.ones((self.batch_size,)) for _ in self.model.sample_weights]
        training_data = self.generator()
        self.validation_data = training_data + sample_weights
        TensorBoard.on_epoch_end(self, epoch, logs)
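# The wrapper above exists because Keras' stock TensorBoard callback only
# writes histograms and gradients when validation_data is available; pulling
# a fresh batch from `generator` at the end of each epoch keeps
# histogram_freq > 0 working even without a conventional fit() validation split.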
class NewDQNAgent(ReinforcementAgent):
    def huber_loss(self, y_true, y_pred):
        '''
        Huber loss, following the definition on Wikipedia:
        L(e) = 1/2 e^2 if |e| <= d, else d * (|e| - d/2)
        '''
        error = tf.math.subtract(y_true, y_pred)
        abs_error = tf.math.abs(error)
        # Split |e| into a quadratic part (capped at delta) and a linear remainder.
        quadratic = tf.math.minimum(abs_error, self.huber_delta)
        linear = tf.math.subtract(abs_error, quadratic)
        loss = tf.math.add(
            tf.math.multiply(tf.constant(0.5, dtype=quadratic.dtype),
                             tf.math.multiply(quadratic, quadratic)),
            tf.math.multiply(self.huber_delta, linear))
        return loss
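    # Worked example, assuming huber_delta = 1.0:
    #   |e| = 0.5 -> quadratic branch: 0.5 * 0.5^2           = 0.125
    #   |e| = 3.0 -> linear branch:    0.5 * 1^2 + 1.0 * 2.0 = 2.5  (= d * (|e| - d/2))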
    def generate_filenames(self):
        self.model_file = self.filename_generator("model", "h5")
        self.parameters_file = self.filename_generator("params", "pkl")
        self.memory_file = self.filename_generator("memory", "pkl")
        self.version_file = self.path + "v.pkl"

    def filename_generator(self, filename, ext):
        return lambda v: self.path + filename + "_{}.{}".format(v, ext)
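    # e.g. self.model_file(3) then resolves to self.path + "model_3.h5".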
    def setup_filesystem(self, remote, layoutName, saveFile):
        folder = "data" if remote else "localdata"
        self.path = "/home/skusku/" + folder + "/machinelearning/save_states/" + layoutName + "/" + saveFile + "/"
        self.log_dir = self.path + "logs/"
        self.generate_filenames()
        if not os.path.exists(self.path):
            os.makedirs(self.path)
    def sample_replay_memory(self, batch_size):
        # np.random.randint draws from [0, len) with the high end exclusive,
        # replacing the deprecated np.random.random_integers.
        idxs = np.random.randint(0, len(self.replay_memory), batch_size)
        return [self.replay_memory[i] for i in idxs]
    def get_validation_set(self):
        memories = self.sample_replay_memory(self.batch_size)
        observations_batch, nextObservations_batch, \
            actions_batch, reward_batch, nonterminal_batch = self.get_batches_from_memories(memories)
        q_values = self.generate_targets(observations_batch,
                                         nextObservations_batch,
                                         actions_batch,
                                         reward_batch,
                                         nonterminal_batch)
        return [np.array(observations_batch), np.array(q_values)]
    def get_epsilon(self):
        # Anneal epsilon linearly from start_epsilon to end_epsilon over `decay` steps.
        decayed = self.start_epsilon - (self.start_epsilon - self.end_epsilon) / self.decay * self.step
        return max(decayed, self.end_epsilon)
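    # e.g. with start_epsilon=1.0, end_epsilon=0.1, decay=300000:
    # step 150000 gives 1.0 - 0.9 * 0.5 = 0.55, clamped to 0.1 from step 300000 on.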
    def increment_step(self):
        self.step += 1
        self.game_step += 1
    def try_loading_previous_version(self):
        # Find the latest saved version, then restore memory, parameters and model.
        if os.path.isfile(self.version_file):
            with open(self.version_file, "rb") as ipt:
                self.sub_version = pickle.load(ipt)
        if os.path.isfile(self.memory_file(self.sub_version)):
            with open(self.memory_file(self.sub_version), "rb") as ipt:
                self.replay_memory = pickle.load(ipt)
            print("Loaded previous memory successfully")
        if os.path.isfile(self.parameters_file(self.sub_version)):
            with open(self.parameters_file(self.sub_version), "rb") as ipt:
                self.step = pickle.load(ipt)
                self.sub_version = pickle.load(ipt)
                self.epoch = pickle.load(ipt)
            print("Restarting in subversion {} from step {}, epoch {}".format(self.sub_version, self.step, self.epoch))
        if os.path.isfile(self.model_file(self.sub_version)):
            custom_objects = {"huber_loss": self.huber_loss}
            self.model = load_model(self.model_file(self.sub_version), custom_objects=custom_objects)
    def __init__(self, layout=None, remote=0, layoutName="mediumGrid", saveFile="testfile", decay=300000, **args):
        ReinforcementAgent.__init__(self, **args)
        signal.signal(signal.SIGINT, self.cleanup)
        self.model = None
        self.replay_memory = None
        self.nb_episodes_between_backups = 5000
        self.setup_filesystem(remote, layoutName, saveFile)
        self.sub_version = 0
        self.huber_delta = tf.constant(1., dtype="float32")
        self.window_length = 2
        self.input_shape = (self.window_length, layout.width, layout.height)
        self.batch_size = 32
        self.learning_rate = .000025
        self.gamma = 1
        self.memory_size = 300000
        self.nb_actions = 5
        self.nb_warmup_steps = 1000
        self.nb_max_rnd_start_steps = 10
        self.nb_rnd_start_steps = 0
        self.last_observations = deque(maxlen=self.window_length)
        self.decay = decay
        self.step = 0
        self.game_step = 0
        self.epoch = 0
        self.final_score = 0
        self.last_loss = 0
        self.start_epsilon = 1.0
        self.end_epsilon = 0.1
        self.ipt = Input(shape=self.input_shape)
        self.permute = Permute((2, 3, 1))(self.ipt)
        self.c1 = Convolution2D(32, (3, 3), strides=(1, 1), bias_initializer=Zeros(), kernel_initializer=VarianceScaling(scale=2))(self.permute)
        self.a1 = Activation('relu')(self.c1)
        self.c2 = Convolution2D(64, (3, 3), strides=(1, 1), bias_initializer=Zeros(), kernel_initializer=VarianceScaling(scale=2))(self.a1)
        self.a2 = Activation('relu')(self.c2)
        self.c3 = Convolution2D(64, (3, 3), strides=(1, 1), bias_initializer=Zeros(), kernel_initializer=VarianceScaling(scale=2))(self.a2)
        self.a3 = Activation('relu')(self.c3)
        self.flat = Flatten()(self.a3)
        self.dense = Dense(self.nb_actions, bias_initializer=Zeros(), kernel_initializer=VarianceScaling(scale=2))(self.flat)
        self.out = Activation('linear')(self.dense)
        self.try_loading_previous_version()
        if self.model is None:
            self.model = Model(inputs=self.ipt, outputs=self.out)
            self.model.compile(loss=self.huber_loss, optimizer=Adam(lr=self.learning_rate))
        if self.replay_memory is None:
            self.replay_memory = deque(maxlen=self.memory_size)
        tb = TensorBoardWrap(generator=self.get_validation_set, log_dir=self.log_dir, write_graph=True, write_grads=True, histogram_freq=100, batch_size=self.batch_size)
        self.callbacks = CallbackList(callbacks=[tb])
        self.callbacks.set_model(self.model)
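    # Network sketch: (window, width, height) frames -> Permute to channels-last
    # -> three 3x3 ReLU convolutions (32/64/64 filters) -> Flatten
    # -> Dense(nb_actions) with linear output: one Q-value per action.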
    def getAction(self, state):
        # Basically the forward pass.
        action = None
        observation = state.data.asArray()
        self.last_observations.append(observation)
        if self.isInTraining():
            # Take the epsilon-greedy action.
            eps = self.get_epsilon()
        else:
            # Take the greedy action.
            eps = 0
        rnd = np.random.uniform(0, 1)
        if self.game_step < self.nb_rnd_start_steps or rnd < eps:
            action = np.random.choice(self.getLegalActions(state))
        else:
            q_values = self.model.predict_on_batch(np.reshape(list(self.last_observations), (1,) + self.input_shape))[0]
            # Pick the highest-valued action that is actually legal in this state.
            sorted_indices_decreasing = np.argsort(q_values)[::-1]
            for idx in sorted_indices_decreasing:
                if idx in Actions.actionsAsIndices(self.getLegalActions(state)):
                    action = Actions._directionsAsList[idx][0]
                    break
        self.increment_step()
        self.doAction(state, action)
        return action
    def generate_targets(self, observations_batch, nextObservations_batch, actions_batch, reward_batch, nonterminal_batch):
        # First predict on the batch to get the current q_values,
        q_values = self.model.predict_on_batch(np.array(observations_batch))
        q_values_next = self.model.predict_on_batch(np.array(nextObservations_batch))
        # then overwrite the q_value of the action we actually took with its target.
        for idx, q in enumerate(q_values):
            q[actions_batch[idx]] = reward_batch[idx] + self.gamma * q_values_next[idx, actions_batch[idx]] * nonterminal_batch[idx]
        return q_values
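    # The target above is r + gamma * Q(s', a) * [nonterminal], reusing the
    # index of the action taken. A textbook DQN target would take the max over
    # next-state actions instead; a hypothetical variant (not what this agent
    # does) would read:
    #   q[actions_batch[idx]] = reward_batch[idx] + \
    #       self.gamma * np.max(q_values_next[idx]) * nonterminal_batch[idx]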
    def get_batches_from_memories(self, memories):
        observations_batch = []
        actions_batch = []
        nextObservations_batch = []
        reward_batch = []
        nonterminal_batch = []
        for memory in memories:
            observations_batch.append(memory['observations'])
            actions_batch.append(memory['actions'])
            nextObservations_batch.append(memory['nextObservations'])
            reward_batch.append(memory['rewards'])
            nonterminal_batch.append(memory['nonterminal'])
        return (observations_batch, nextObservations_batch, actions_batch, reward_batch, nonterminal_batch)
    def update(self, state, action, nextState, reward):
        # Bail out until we have seen enough states to fill our observation window.
        if len(self.last_observations) != self.window_length:
            return
        nextObservation = None if nextState is None else nextState.data.asArray()
        self.replay_memory.append(
            {
                "observations": np.reshape(list(self.last_observations), self.input_shape),
                "actions": Actions.actionsAsIndices([action])[0],
                "nextObservations": np.reshape(list(self.last_observations)[1:] + [nextObservation], self.input_shape),
                "rewards": reward,
                "nonterminal": 1
            })
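        # "nonterminal" defaults to 1 here; final() flips the flag on the last
        # memory to 0 once the episode actually ends.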
        # Bail out while the replay memory is still warming up.
        if self.step < self.nb_warmup_steps:
            return
        memories = self.sample_replay_memory(self.batch_size)
        observations_batch, nextObservations_batch, \
            actions_batch, reward_batch, nonterminal_batch = self.get_batches_from_memories(memories)
        q_values = self.generate_targets(observations_batch,
                                         nextObservations_batch,
                                         actions_batch,
                                         reward_batch,
                                         nonterminal_batch)
        # And train on this batch.
        self.last_loss = self.model.train_on_batch(x=np.array(observations_batch), y=np.array(q_values))
    def saveEverything(self):
        # Save the model.
        self.model.save(self.model_file(self.sub_version))
        # Pickle needs binary mode ("wb"); the loaders above already use "rb".
        with open(self.memory_file(self.sub_version), "wb") as output:
            pickle.dump(self.replay_memory, output, pickle.HIGHEST_PROTOCOL)
        with open(self.parameters_file(self.sub_version), "wb") as output:
            pickle.dump(self.step, output, pickle.HIGHEST_PROTOCOL)
            pickle.dump(self.sub_version, output, pickle.HIGHEST_PROTOCOL)
            pickle.dump(self.epoch, output, pickle.HIGHEST_PROTOCOL)
        with open(self.version_file, "wb") as opt:
            pickle.dump(self.sub_version, opt, pickle.HIGHEST_PROTOCOL)
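        # Note: the dump order (step, sub_version, epoch) must mirror the load
        # order in try_loading_previous_version, since pickle reads sequentially.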
        # After saving the newest state, delete the older state to save some space...
        if self.sub_version > 0:
            try:
                os.remove(self.model_file(self.sub_version - 1))
                os.remove(self.memory_file(self.sub_version - 1))
                os.remove(self.parameters_file(self.sub_version - 1))
            except OSError:
                print("Previous version was already deleted.")
        print("Saved Model, dumped memory and parameters to pickle file, version {}.".format(self.sub_version))
        self.sub_version += 1
    def cleanup(self, sig, frame):
        self.saveEverything()
        sys.exit(0)

    def startEpisode(self):
        ReinforcementAgent.startEpisode(self)
        # randint's high end is exclusive, hence the +1 (random_integers is deprecated).
        self.nb_rnd_start_steps = np.random.randint(self.window_length, self.nb_max_rnd_start_steps + 1)
        self.game_step = 0
        self.callbacks.on_epoch_begin(self.epoch)
        self.last_observations.clear()
    def stopEpisode(self):
        ReinforcementAgent.stopEpisode(self)
        logs = {"reward": self.final_score, "epsilon": self.get_epsilon(), "loss": self.last_loss}
        self.callbacks.on_epoch_end(self.epoch, logs=logs)
        if self.episodesSoFar % self.nb_episodes_between_backups == 0:
            self.saveEverything()
        self.epoch += 1

    def final(self, state):
        ReinforcementAgent.final(self, state)
        # The episode ended, so mark the most recent transition as terminal.
        self.replay_memory[-1]['nonterminal'] = 0
        self.final_score = state.getScore()