Skip to content

Instantly share code, notes, and snippets.

Forked from EderSantana/
Created November 9, 2017 01:38
Show Gist options
  • Save dellis23/7aff5e20d8021e9b96cb67078f1f43ad to your computer and use it in GitHub Desktop.
Save dellis23/7aff5e20d8021e9b96cb67078f1f43ad to your computer and use it in GitHub Desktop.
Keras plays catch - a single file Reinforcement Learning example

Code for the "Keras plays catch" blog post.




  1. Generate figures
  2. Make a GIF
ffmpeg -i %03d.png output.gif -vf fps=1

Alternatively, check cadurosar's IPython notebook; there, you should run cell 2 before cell 1.


  • Prior supervised learning and Keras knowledge
  • Python science stack (numpy, scipy, matplotlib) - Install Anaconda!
  • Theano or TensorFlow
  • Keras (last tested on commit b0303f03ff03)
  • ffmpeg (optional)


This code is released under the MIT license. (Note that Deep Q-Learning is covered by its own patent held by Google.)

import json
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import sgd
class Catch(object):
    """Catch game environment.

    For each run of the "game", a piece of fruit will be dropped in a random
    location. It will continue until it makes its way to the bottom of the
    screen. When it reaches the bottom, the game will be considered over.
    No reward will be given until the end. At the end of one run of fruit
    dropping, the player will either receive a reward of 1 if the basket
    overlapped with the fruit, or -1 if the basket did not overlap with the
    fruit.

    The board is a ``grid_size`` x ``grid_size`` image. The fruit occupies a
    single cell; the basket occupies three cells of the bottom row, centred
    on ``basket_col``.
    """

    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        # Place the fruit and basket so observe() works immediately.
        self.reset()

    def _update_state(self, action):
        """Advance one step: move the basket per ``action`` and drop the fruit.

        ``action`` is 0 (left), 1 (stay) or anything else (right).
        """
        # Our neural network's output values will range from 0 to 2,
        # corresponding to the choice to either move left, stay in the current
        # position, or move right. For simplicity's sake, we can convert these
        # to -1, 0, and 1, and then add them to the current value to get an
        # updated position.
        if action == 0:  # left
            action = -1
        elif action == 1:  # stay
            action = 0
        else:
            action = 1  # right

        # Move the fruit and basket. The basket is three cells wide, so its
        # centre is kept within [1, grid_size - 2]: fully on the board at
        # both edges.
        self.basket_col = min(max(1, self.basket_col + action),
                              self.grid_size - 2)
        self.fruit_row += 1

    def _draw_state(self):
        """Return the board as a 2-D array with 1s on the fruit and basket."""
        im_size = (self.grid_size,) * 2
        canvas = np.zeros(im_size)
        canvas[self.fruit_row, self.fruit_col] = 1  # draw fruit
        canvas[-1, self.basket_col - 1:self.basket_col + 2] = 1  # draw basket
        return canvas

    def _get_reward(self):
        """Return 1 (caught) or -1 (missed) once the game is over, else 0."""
        if not self._is_over():
            return 0
        # The basket spans basket_col - 1 .. basket_col + 1.
        return 1 if abs(self.fruit_col - self.basket_col) <= 1 else -1

    def _is_over(self):
        """Return True once the fruit has reached the bottom row."""
        return self.fruit_row == self.grid_size - 1

    def observe(self):
        """Return the board flattened to shape ``(1, grid_size**2)``."""
        return self._draw_state().reshape((1, -1))

    def act(self, action):
        """Apply ``action`` and return ``(observation, reward, game_over)``."""
        self._update_state(action)
        reward = self._get_reward()
        game_over = self._is_over()
        return self.observe(), reward, game_over

    def reset(self):
        """Start a new game with the fruit on the top row."""
        self.fruit_row = 0
        # np.random.randint excludes its upper bound, so these ranges are
        # fruit_col in [0, grid_size - 1] and basket_col in
        # [1, grid_size - 2] (the same range the movement clamp allows).
        self.fruit_col = np.random.randint(0, self.grid_size)
        self.basket_col = np.random.randint(1, self.grid_size - 1)
class ExperienceReplay(object):
    """Bounded FIFO memory of transitions used to build Q-learning batches.

    Each entry of ``self.memory`` is
    ``[[state_t, action_t, reward_t, state_tp1], game_over]``.
    """

    def __init__(self, max_memory=100, discount=.9):
        self.max_memory = max_memory  # capacity; oldest entries are evicted
        self.memory = list()
        self.discount = discount  # gamma in the Bellman target

    def remember(self, states, game_over):
        """Store one transition, dropping the oldest one when over capacity.

        ``states`` is ``[state_t, action_t, reward_t, state_tp1]``.
        """
        # memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]
        self.memory.append([states, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, model, batch_size=10):
        """Sample stored transitions and build Q-learning training targets.

        Args:
            model: object exposing ``output_shape`` (last entry = number of
                actions) and ``predict(x)`` returning one row of Q-values per
                input row.
            batch_size: maximum number of transitions to sample (with
                replacement); capped at the number stored.

        Returns:
            ``(inputs, targets)`` float arrays of shapes ``(n, env_dim)`` and
            ``(n, num_actions)`` with ``n = min(len(memory), batch_size)``.
            Each target row equals the model's current prediction except at
            the action taken, which holds ``reward_t`` on game over and
            ``reward_t + discount * max_a' Q(state_tp1, a')`` otherwise.

        Raises:
            IndexError: if nothing has been remembered yet.
        """
        len_memory = len(self.memory)
        num_actions = model.output_shape[-1]
        env_dim = self.memory[0][0][0].shape[1]
        n = min(len_memory, batch_size)
        sample = [self.memory[idx]
                  for idx in np.random.randint(0, len_memory, size=n)]

        inputs = np.zeros((n, env_dim))
        next_states = np.zeros((n, env_dim))
        for i, ((state_t, _, _, state_tp1), _) in enumerate(sample):
            inputs[i:i + 1] = state_t
            next_states[i:i + 1] = state_tp1

        targets = np.zeros((n, num_actions))
        if n == 0:
            return inputs, targets

        # Predict the whole batch at once instead of two predict() calls per
        # sample. Starting from the model's own output gives zero error for
        # the actions that were not taken ("Thou shalt not correct actions
        # not taken").
        targets[:] = model.predict(inputs)
        q_next = np.max(model.predict(next_states), axis=1)
        for i, ((_, action_t, reward_t, _), game_over) in enumerate(sample):
            if game_over:
                targets[i, action_t] = reward_t
            else:
                # reward_t + gamma * max_a' Q(s', a')
                targets[i, action_t] = reward_t + self.discount * q_next[i]
        return inputs, targets
if __name__ == "__main__":
    # parameters
    epsilon = .1  # exploration
    num_actions = 3  # [move_left, stay, move_right]
    epoch = 1000
    max_memory = 500
    hidden_size = 100
    batch_size = 50
    grid_size = 10

    # Q-network: maps the flattened board to one Q-value per action.
    model = Sequential()
    model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
    model.add(Dense(hidden_size, activation='relu'))
    # Linear output layer; without it the network would emit hidden_size
    # values and argmax could choose "actions" outside [0, num_actions).
    model.add(Dense(num_actions))
    # NOTE(review): newer Keras releases spell this SGD(learning_rate=.2).
    model.compile(sgd(lr=.2), "mse")

    # If you want to continue training from a previous model, just uncomment the line below
    # model.load_weights("model.h5")

    # Define environment/game
    env = Catch(grid_size)

    # Initialize experience replay object
    exp_replay = ExperienceReplay(max_memory=max_memory)

    # Train
    win_cnt = 0
    for e in range(epoch):
        loss = 0.
        # Start a fresh game; otherwise, after the first game ends, the fruit
        # would keep falling past the bottom row.
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()

        while not game_over:
            input_tm1 = input_t
            # Epsilon-greedy action selection.
            if np.random.rand() <= epsilon:
                # Plain scalar (not a size-1 array), matching the form of the
                # greedy branch's argmax result.
                action = np.random.randint(0, num_actions)
            else:
                q = model.predict(input_tm1)
                action = np.argmax(q[0])

            # apply action, get rewards and new state
            input_t, reward, game_over = env.act(action)
            if reward == 1:
                win_cnt += 1

            # store experience
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)

            # adapt model
            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
            loss += model.train_on_batch(inputs, targets)
        print("Epoch {:03d}/{:03d} | Loss {:.4f} | Win count {}".format(
            e, epoch - 1, loss, win_cnt))

    # Save trained model weights and architecture, this will be used by the visualization code
    model.save_weights("model.h5", overwrite=True)
    with open("model.json", "w") as outfile:
        json.dump(model.to_json(), outfile)
import json
import matplotlib.pyplot as plt
import numpy as np
from keras.models import model_from_json
from qlearn import Catch
def _save_frame(frame, grid_size, index):
    """Save a flattened board observation as the grayscale image ``%03d.png``.

    ``frame`` is an observation from ``Catch.observe()``; ``index`` numbers
    the output file.
    """
    plt.imshow(frame.reshape((grid_size,) * 2),
               interpolation='none', cmap='gray')
    plt.savefig("%03d.png" % index)
    # Clear the figure so successive frames are not stacked on the same axes.
    plt.clf()


if __name__ == "__main__":
    # Make sure this grid size matches the value used from training
    grid_size = 10

    # Rebuild the architecture saved by the training script, then restore
    # the trained weights it saved separately to model.h5.
    with open("model.json", "r") as jfile:
        model = model_from_json(json.load(jfile))
    model.load_weights("model.h5")
    model.compile("sgd", "mse")

    # Define environment, game
    env = Catch(grid_size)
    c = 0  # running frame number across all games
    for _ in range(10):
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()
        _save_frame(input_t, grid_size, c)
        c += 1
        while not game_over:
            # Greedy action: the move with the highest predicted Q-value.
            q = model.predict(input_t)
            action = np.argmax(q[0])

            # apply action, get rewards and new state
            input_t, reward, game_over = env.act(action)
            _save_frame(input_t, grid_size, c)
            c += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment