Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Keras plays catch - a single file Reinforcement Learning example

Code for Keras plays catch blog post

Train

python qlearn.py

Test

  1. Generate figures
python test.py
  2. Make gif
ffmpeg -i %03d.png output.gif -vf fps=1

Alternatively, check cadurosar's IPython notebook; note that there you should run cell 2 before cell 1.

Requirements

  • Prior supervised learning and Keras knowledge
  • Python science stack (numpy, scipy, matplotlib) - Install Anaconda!
  • Theano or Tensorflow
  • Keras (last tested on commit b0303f03ff03)
  • ffmpeg (optional)

License

This code is released under MIT license. (Note that Deep Q-Learning has its own patent by Google)

import json
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import sgd
class Catch(object):
    """Minimal "catch the falling fruit" grid game.

    One fruit falls one row per step from the top row of a
    ``grid_size`` x ``grid_size`` board while a basket, drawn three cells
    wide, moves along the bottom row.  The whole game state is kept in
    ``self.state``, a ``(1, 3)`` integer array laid out as
    ``[[fruit_row, fruit_col, basket_center]]``.

    ``grid_size`` must be at least 4, otherwise the random range used for
    the basket in ``reset`` is empty.
    """

    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        self.reset()

    def _update_state(self, action):
        """Advance the game by one step.

        Input: action (0 = move left, 1 = stay, anything else = move right)
        Output: nothing is returned; ``self.state`` is replaced by the new
        state (fruit one row lower, basket moved and clamped).
        """
        # Callers may pass a plain int, a NumPy integer or a size-1 array;
        # collapse all of them to an int so the state below stays a flat,
        # homogeneous integer vector.
        action = int(np.asarray(action).item())
        if action == 0:  # left
            shift = -1
        elif action == 1:  # stay
            shift = 0
        else:  # right
            shift = 1
        fruit_row, fruit_col, basket = self.state[0]
        # NOTE: the upper clamp lets the basket centre reach the last
        # column, where only two of its three cells fit on the canvas.
        new_basket = min(max(1, basket + shift), self.grid_size - 1)
        fruit_row += 1
        out = np.asarray([fruit_row, fruit_col, new_basket])[np.newaxis]
        assert len(out.shape) == 2
        self.state = out

    def _draw_state(self):
        """Render the state as a ``grid_size`` x ``grid_size`` 0/1 canvas."""
        fruit_row, fruit_col, basket = self.state[0]
        canvas = np.zeros((self.grid_size,) * 2)
        canvas[fruit_row, fruit_col] = 1  # draw fruit
        canvas[-1, basket - 1:basket + 2] = 1  # draw basket
        return canvas

    def _get_reward(self):
        """Return 0 while the fruit is falling, then +1 (caught) or -1."""
        fruit_row, fruit_col, basket = self.state[0]
        if fruit_row != self.grid_size - 1:
            return 0
        # The fruit counts as caught when it lands on any of the three
        # basket cells.
        return 1 if abs(fruit_col - basket) <= 1 else -1

    def _is_over(self):
        """Return True once the fruit has reached the bottom row."""
        return bool(self.state[0, 0] == self.grid_size - 1)

    def observe(self):
        """Return the current canvas flattened to shape ``(1, grid_size**2)``."""
        return self._draw_state().reshape((1, -1))

    def act(self, action):
        """Apply ``action`` and return ``(observation, reward, game_over)``."""
        self._update_state(action)
        reward = self._get_reward()
        game_over = self._is_over()
        return self.observe(), reward, game_over

    def reset(self):
        """Start a new game with the fruit in the top row.

        Scalars (not size-1 arrays) are drawn so that the state is built
        from a homogeneous sequence; recent NumPy versions refuse to build
        an array from a mix of scalars and arrays.
        """
        fruit_col = np.random.randint(0, self.grid_size - 1)
        basket = np.random.randint(1, self.grid_size - 2)
        self.state = np.asarray([0, fruit_col, basket])[np.newaxis]
class ExperienceReplay(object):
    """Fixed-size replay memory that also builds Q-learning training batches.

    Each memory entry is ``[[state_t, action_t, reward_t, state_t+1],
    game_over]``.  Once more than ``max_memory`` entries are stored the
    oldest one is dropped.  ``discount`` is the gamma factor applied to the
    bootstrapped value of the next state.
    """

    def __init__(self, max_memory=100, discount=.9):
        self.max_memory = max_memory
        self.memory = list()
        self.discount = discount

    def remember(self, states, game_over):
        """Store one transition and evict the oldest if the memory is full.

        ``states`` is ``[state_t, action_t, reward_t, state_t+1]``.
        """
        # memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]
        self.memory.append([states, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, model, batch_size=10):
        """Sample transitions (with replacement) and build training targets.

        ``model`` must expose ``output_shape`` (last entry = number of
        actions) and ``predict``.  Returns ``(inputs, targets)`` with
        ``min(len(memory), batch_size)`` rows each.

        Raises ``IndexError`` if nothing has been remembered yet.
        """
        len_memory = len(self.memory)
        if len_memory == 0:
            raise IndexError("cannot sample a batch from an empty replay memory")
        num_actions = model.output_shape[-1]
        env_dim = self.memory[0][0][0].shape[1]
        n_samples = min(len_memory, batch_size)

        inputs = np.zeros((n_samples, env_dim))
        next_states = np.zeros((n_samples, env_dim))
        actions = np.zeros(n_samples, dtype=int)
        rewards = np.zeros(n_samples)
        terminal = np.zeros(n_samples, dtype=bool)
        picks = np.random.randint(0, len_memory, size=n_samples)
        for i, idx in enumerate(picks):
            (state_t, action_t, reward_t, state_tp1), game_over = self.memory[idx]
            inputs[i:i + 1] = state_t
            next_states[i:i + 1] = state_tp1
            # Stored actions may be ints or size-1 arrays.
            actions[i] = np.asarray(action_t).item()
            rewards[i] = reward_t
            terminal[i] = game_over

        # Two batched forward passes instead of two per sampled transition.
        # There should be no target values for actions not taken, so start
        # from the model's own predictions and overwrite only the column of
        # the action that was actually played.
        targets = np.zeros((n_samples, num_actions))
        targets[:] = model.predict(inputs)
        max_next_q = np.max(model.predict(next_states), axis=1)
        # Terminal transitions get the bare reward; the others get
        # reward_t + gamma * max_a' Q(s', a').
        targets[np.arange(n_samples), actions] = np.where(
            terminal, rewards, rewards + self.discount * max_next_q)
        return inputs, targets
if __name__ == "__main__":
    # Training script: fit a small dense Q-network on the Catch game using
    # epsilon-greedy exploration and experience replay, then save the
    # weights ("model.h5") and the architecture ("model.json").

    # parameters
    epsilon = .1  # exploration
    num_actions = 3  # [move_left, stay, move_right]
    epoch = 1000
    max_memory = 500
    hidden_size = 100
    batch_size = 50
    grid_size = 10

    model = Sequential()
    model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
    model.add(Dense(hidden_size, activation='relu'))
    model.add(Dense(num_actions))
    model.compile(sgd(lr=.2), "mse")

    # If you want to continue training from a previous model, just uncomment the line below
    # model.load_weights("model.h5")

    # Define environment/game
    env = Catch(grid_size)

    # Initialize experience replay object
    exp_replay = ExperienceReplay(max_memory=max_memory)

    # Train
    win_cnt = 0
    for e in range(epoch):
        loss = 0.
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()

        while not game_over:
            input_tm1 = input_t
            # get next action
            if np.random.rand() <= epsilon:
                # Draw a plain integer so the explored action has the same
                # scalar form as the greedy branch's argmax.
                action = np.random.randint(0, num_actions)
            else:
                q = model.predict(input_tm1)
                action = np.argmax(q[0])

            # apply action, get rewards and new state
            input_t, reward, game_over = env.act(action)
            if reward == 1:
                win_cnt += 1

            # store experience
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)

            # adapt model
            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
            # train_on_batch returns either a bare scalar loss or a list
            # whose first entry is the loss, depending on the Keras version
            # and compiled metrics; ravel() lets both be read the same way.
            batch_loss = model.train_on_batch(inputs, targets)
            loss += float(np.ravel(batch_loss)[0])
        print("Epoch {:03d}/{:03d} | Loss {:.4f} | Win count {}".format(
            e, epoch - 1, loss, win_cnt))

    # Save trained model weights and architecture, this will be used by the visualization code
    model.save_weights("model.h5", overwrite=True)
    with open("model.json", "w") as outfile:
        json.dump(model.to_json(), outfile)
import json
import matplotlib.pyplot as plt
import numpy as np
from keras.models import model_from_json
from qlearn import Catch
if __name__ == "__main__":
    # Visualization script: load the trained model, let it play ten games
    # greedily and save every frame as a numbered PNG ("000.png", ...).

    # Make sure this grid size matches the value used for training
    grid_size = 10

    with open("model.json", "r") as jfile:
        model = model_from_json(json.load(jfile))
    model.load_weights("model.h5")
    model.compile("sgd", "mse")

    # Define environment, game
    env = Catch(grid_size)
    c = 0  # running frame counter, used as the PNG file name
    for e in range(10):
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()

        while True:
            # Save the current frame: the initial board first, then the
            # board after each action, including the final one.
            plt.clf()  # start from an empty figure so frames don't stack up
            plt.imshow(input_t.reshape((grid_size,) * 2),
                       interpolation='none', cmap='gray')
            plt.savefig("%03d.png" % c)
            c += 1
            if game_over:
                break

            # get next action (always greedy, no exploration)
            q = model.predict(input_t)
            action = np.argmax(q[0])
            # apply action, get new state; the reward is not needed here
            input_t, _, game_over = env.act(action)

Really nice reinforcement learning example, I made a ipython notebook version of the test that instead of saving the figure it refreshes itself, its not that good (you have to execute cell 2 before cell 1) but could be usefull if you want to easily see the evolution of the model.

https://gist.github.com/cadurosar/bd54c723c1d6335a43c8

PS: Revision 6 has some bugs like targest instead of targets

Owner

EderSantana commented Mar 19, 2016

@cadurosar Tkx for the contribution! I added a link to your code on README.

Hi Eder,

Thanks for the really useful keras example. I have a question on your experience replay implementation.

The loss is calculated between the output of experience replay samples (lets call it OER) and calculated targets. Now, the action chosen at OER is the exact same as the ones that were stored in the experience. This is implemented indirectly in line 103 targets[i, action_t] = reward_t + self.discount * Q_sa by inserting the target in appropriate column. When you calculate MSE loss between the targets tensor and network output (each of shape inputs.shape[0]*num_actions), wouldn't all the values in network output other than the value corresponding to the non-zero target add to the loss? Since each target row has only one non-zero value.

Really hope I'm not misreading your code or being confusing.

k3nt0 commented May 8, 2016

Hello, Eder

Thanks for the nice reinforcement example. I could study about reinforcement learning efficiently.

I'm not good at English, but I hope it's understable to you.

By the way, I have an Idea for more good train. I think the basket should wait under the fruit before it get fall to the ground. So I changed definition of _get_reward() like this.

def _get_reward_2nd(self):
        fruit_row, fruit_col, basket = self.state[0]
        if abs(fruit_col - basket) <= 1:
            return 1
        else:
            return -1

After changing definition, the win_cnt soared like this,

Epoch 996/999 | Loss 0.0043 | Win count 7166
Epoch 997/999 | Loss 0.0051 | Win count 7173
Epoch 998/999 | Loss 0.0048 | Win count 7182
Epoch 999/999 | Loss 0.0058 | Win count 7186

I'm sorry if you have already know this. I hope my opinion would be your help.

thanks for sharing

qrpike commented May 24, 2016 edited

Copy pasted and got:

$ python qlearn.py 
Using Theano backend.
Traceback (most recent call last):
  File "qlearn.py", line 164, in <module>
    loss += model.train_on_batch(inputs, targets)[0]
IndexError: too many indices for array

Sohojoe commented May 24, 2016 edited

I get the following error; running on Keras (1.0.3) pythion 2.7

  File "qlearn.py", line 164, in <module>
    loss += model.train_on_batch(inputs, targets)[0]
IndexError: invalid index to scalar variable.

i was able to fix it by removing the [0] from line 164
loss += model.train_on_batch(inputs, targets)

qrpike commented May 25, 2016

@Sohojoe Sorry, I don't usually work in python. THANK YOU for the solution though. Works fine now

cloudyangyy commented Jun 6, 2016 edited

I'm learning RL, thanks for your share.
If I set grid_size =50 ,It could not work well.
I have tryed Convolutional Layer,It also not work well.
Do you have any idea? Thanks.

danijar commented Jul 17, 2016

I think Q_sa = np.max(model.predict(state_tp1)[0]) should be predicted by the "target network", meaning using the weights from the previous timestep. Am I mistaken of is this missing from the code?

luli395 commented Dec 12, 2016 edited

The code is nice and clear. In the code, the original game assumes the fruit falls straight down from the top, and does not move horizontally. I wonder what if the fruit move horizontally while falling, so I added few codes to make the fruit move horizontally. The result is interesting: the algorithm can easily learn good policy to catch the fruit even if the fruit move randomly along x-axis while falling as well.

def _update_state(self, action):
"""
Input: action and states
Ouput: new states and reward
"""
state = self.state
if action == 0: # left
action = -1
elif action == 1: # stay
action = 0
else:
action = 1 # right
f0, f1, basket = state[0]
new_basket = min(max(1, basket + action), self.grid_size-1)
df1=np.random.randint(0,3,1)
if df1==0:
f1-=1
elif df1==2:
f1+=1
else:
pass

    if f1>self.grid_size-1:
        f1=self.grid_size-1
    if f1<0:
        f1=0
        
    f0 += 1
    
    out = np.asarray([f0, f1, new_basket])
    out = out[np.newaxis]

First of all, Thanks for your sharing!!!!

I tried to implement your code, but there happened error such that:


IOError Traceback (most recent call last)
in ()
11
12
---> 13 with open("model.json", "r") as jfile:
14 model = model_from_json(json.load(jfile))
15 model.load_weights("model.h5")

IOError: [Errno 2] No such file or directory: 'model.json'

I did googling for solving this problem... but,, I failed..

could you help me?

arr28 commented Feb 3, 2017

@wonchul-kim: Looks like you ran test.py before running qlearn.py.

Hey, so I was playing around with the code and I have a big (pun intended) problem with the neural network output. That is, model.predicts very quickly diverges to infinity (even on the first batch). Did you ever encounter something like this or any ideas where this could be coming from. I am using the code to do reinforcement on a different game.

suryavanshi commented Mar 24, 2017 edited

How can I incrementally fine tune a trained model if I change the the falling speed or some other parameters of the game slightly?

@qrpike what was the solution of your problem? You should write answer instead 'say thanks'!

I've created a Paddle Ball game using Deep Q Learning based on this work. Thanks Eder for the basics! It is really helpful for a biginner like me.
https://github.com/azhar2205/paddle-ball-using-dqlearn

Where does one get the code for these simple games like catch ?
Thank you

@EderSantana thanks for the great and informative project!
@cadurosar thank you for contributing the notebook version!
I made a version with more comments and explanations for teaching purposes. It can be run cell by cell like a walk through tutorial. It runs in python 2.7 as well as 3.5 and fixes some of the errors mentioned in the comments above. You can find it here: https://github.com/JannesKlaas/sometimes_deep_sometimes_learning/blob/master/reinforcement.ipynb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment