import random
import datetime
import gym
import gym.spaces
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten
from keras import optimizers
from collections import deque

LOG_LEVEL = 2
np.random.seed(int(input("Training seed? ")))
# Like original Atari Breakout, but with a fixed frame skip of 4, so frames are not randomly skipped while playing
env = gym.make('BreakoutDeterministic-v4')
""" | |
Preprocessing steps: | |
Convert image to grayscale using the average of the R, G, B values at the pixel | |
Downscale the image by 2, converting the 210 by 160 image into 105 by 80 | |
""" | |
def preprocess(img): | |
img = np.mean(img, axis=2).astype(np.uint8) # converts the images from color to grayscale | |
img = img[::2, ::2] # splices the image for every other pixel to reduce image size by 2 on both axes | |
return img # returns an image of size 105x80 pixels | |
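# For reference: a raw Breakout frame is a 210x160x3 uint8 array, so
# preprocess(np.zeros((210, 160, 3), dtype=np.uint8)) yields an array of shape (105, 80).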
def create_model():
    model = Sequential()
    model.add(Conv2D(32, (8, 8), strides=(4, 4), activation='relu', input_shape=(105, 80, 4)))
    model.add(Conv2D(64, (4, 4), strides=(2, 2), activation='relu'))
    model.add(Conv2D(64, (3, 3), strides=(1, 1), activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(4))  # linear output layer: one Q-value per action
    return model
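# Note: the network input is a stack of the 4 most recent preprocessed frames (105x80x4), and the
# 4 outputs correspond to Breakout's Discrete(4) action space (NOOP, FIRE, RIGHT, LEFT),
# so a single forward pass produces an estimated Q-value for every action at once.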
# Performs one batch of training on the model
def train_model(model, target_model, minibatch, gamma):
    # define targets for each (s, a, r, s') transition
    q_targets = np.empty([len(minibatch), 4])
    batch_states = np.empty([len(minibatch), 105, 80, 4])
    for i, (state, action, reward, next_state) in enumerate(minibatch):
        batch_states[i] = state
        q_pred = target_model.predict(state)
        q_pred_next = target_model.predict(next_state)
        q_targets[i, :] = q_pred  # actions not taken keep their predicted values as targets
        q_targets[i, action] = reward + gamma * np.max(q_pred_next)  # Q-learning target for the action taken
        # print(q_targets[i], action, q_pred_next, reward)
    if LOG_LEVEL > 1: print("Loss:", model.train_on_batch(batch_states, q_targets))
    return model
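# In other words, the regression target for the chosen action follows the standard Q-learning update,
#     target(s, a) = r + gamma * max_a' Q_target(s', a')
# e.g. with gamma = 0.90, a reward of 1 and max_a' Q_target(s', a') = 2.0, the target is 1 + 0.90 * 2.0 = 2.8.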
# Hyperparameters
rendering = 0
q_tracking = 0
epsilon = 1.
epsilon_min = 0.1
exploration_steps = 100000
epsilon_anneal = (epsilon - epsilon_min) / exploration_steps
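# With these values, epsilon decays linearly from 1.0 to 0.1 over 100,000 action selections,
# i.e. by (1.0 - 0.1) / 100000 = 9e-6 each time an action is chosen.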
gamma = 0.90
episodes = 10000
batch_size = 32
lr = 0.00025
max_updates = 500  # number of training steps to run before the next sync of the target network

training_model = create_model()
target_model = create_model()
opt = optimizers.RMSprop(lr=lr)
training_model.compile(loss='mse', optimizer=opt, metrics=['mae'])
target_model.compile(loss='mse', optimizer=opt, metrics=['mae'])
replay_memory = deque(maxlen=100000)
t_list = []
r_list = []
q_data = []
updates = 0
for ep in range(episodes):
    obs = env.reset()
    done = False
    num_frame = 1
    reward_sum = 0
    ep_reward_sum = 0
    last_a = np.random.randint(4)
    state = []
    last_state = []
    last_reward = 0
    timestep = 1
    q_data.clear()
    while not done:
        obs, reward, done, _ = env.step(last_a)
        state.append(preprocess(obs))
        reward_sum += reward
        ep_reward_sum += reward
        num_frame += 1
        if rendering: env.render()
        if num_frame > 4:  # choose a new action using the stacked state every 4 frames
            # Record the (s, a, r, s') transition for the last state (since we now know the new state)
            if timestep != 1:
                replay_memory.append((np.stack(last_state, axis=-1).reshape([1, 105, 80, 4]), last_a, reward,
                                      np.stack(state, axis=-1).reshape([1, 105, 80, 4])))
            # Select an action
            Q_pred = training_model.predict(np.stack(state, axis=-1).reshape([1, 105, 80, 4]))
            if len(q_data) > 200: q_data.pop(0)
            q_data.append(np.max(Q_pred))
            if q_tracking:
                plt.gcf().clear()
                plt.plot(q_data, 'b-')
                plt.gca().set_xbound([0, 200])
                plt.gca().set_ybound([0, 3])
                plt.pause(0.002)
            if np.random.rand() < epsilon:
                a = env.action_space.sample()
            else:
                a = np.argmax(Q_pred)
            if epsilon > epsilon_min: epsilon -= epsilon_anneal
            num_frame = 1
            last_a = a
            last_state = state
            last_reward = reward_sum
            reward_sum = 0
            state = []
            timestep += 1
            # Training the model using experience replay
            if len(replay_memory) < batch_size: continue  # don't train until there is at least one batch worth of transitions
            minibatch = random.sample(replay_memory, batch_size)
            training_model = train_model(training_model, target_model, minibatch, gamma)
            updates += 1
            if updates > max_updates:
                updates = 0
                if LOG_LEVEL > 1: print("Updated Target Network")
                training_model.save('saved_models/model_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.h5')
                target_model.set_weights(training_model.get_weights())
    r_list.append(ep_reward_sum)
    t_list.append(timestep)
    if LOG_LEVEL > 0: print("Episode", ep + 1, "completed. Reward collected:", ep_reward_sum)
env.close()
# save the final trained model
training_model.save('saved_models/model_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.h5')
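# To evaluate a saved checkpoint later, the .h5 file can be reloaded with Keras
# (the path below is a placeholder; the real filename depends on the timestamp at save time):
#   from keras.models import load_model
#   model = load_model('saved_models/model_<timestamp>.h5')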