Skip to content

Instantly share code, notes, and snippets.

@EderSantana
Last active June 22, 2024 17:07
Show Gist options
  • Save EderSantana/c7222daa328f0e885093 to your computer and use it in GitHub Desktop.
Save EderSantana/c7222daa328f0e885093 to your computer and use it in GitHub Desktop.
Keras plays catch - a single file Reinforcement Learning example

Code for Keras plays catch blog post

Train

python qlearn.py

Test

  1. Generate figures
python test.py
  1. Make gif
ffmpeg -i %03d.png output.gif -vf fps=1

Alternatively, check cadurosar ipython notebook, there you should run cell 2 before cell 1.

Requirements

  • Prior supervised learning and Keras knowledge
  • Python science stack (numpy, scipy, matplotlib) - Install Anaconda!
  • Theano or Tensorflow
  • Keras (last testest on commit b0303f03ff03)
  • ffmpeg (optional)

License

This code is released under MIT license. (Note that Deep Q-Learning has its own patent by Google)

import json
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import sgd
class Catch(object):
def __init__(self, grid_size=10):
self.grid_size = grid_size
self.reset()
def _update_state(self, action):
"""
Input: action and states
Ouput: new states and reward
"""
state = self.state
if action == 0: # left
action = -1
elif action == 1: # stay
action = 0
else:
action = 1 # right
f0, f1, basket = state[0]
new_basket = min(max(1, basket + action), self.grid_size-1)
f0 += 1
out = np.asarray([f0, f1, new_basket])
out = out[np.newaxis]
assert len(out.shape) == 2
self.state = out
def _draw_state(self):
im_size = (self.grid_size,)*2
state = self.state[0]
canvas = np.zeros(im_size)
canvas[state[0], state[1]] = 1 # draw fruit
canvas[-1, state[2]-1:state[2] + 2] = 1 # draw basket
return canvas
def _get_reward(self):
fruit_row, fruit_col, basket = self.state[0]
if fruit_row == self.grid_size-1:
if abs(fruit_col - basket) <= 1:
return 1
else:
return -1
else:
return 0
def _is_over(self):
if self.state[0, 0] == self.grid_size-1:
return True
else:
return False
def observe(self):
canvas = self._draw_state()
return canvas.reshape((1, -1))
def act(self, action):
self._update_state(action)
reward = self._get_reward()
game_over = self._is_over()
return self.observe(), reward, game_over
def reset(self):
n = np.random.randint(0, self.grid_size-1, size=1)
m = np.random.randint(1, self.grid_size-2, size=1)
self.state = np.asarray([0, n, m])[np.newaxis]
class ExperienceReplay(object):
def __init__(self, max_memory=100, discount=.9):
self.max_memory = max_memory
self.memory = list()
self.discount = discount
def remember(self, states, game_over):
# memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]
self.memory.append([states, game_over])
if len(self.memory) > self.max_memory:
del self.memory[0]
def get_batch(self, model, batch_size=10):
len_memory = len(self.memory)
num_actions = model.output_shape[-1]
env_dim = self.memory[0][0][0].shape[1]
inputs = np.zeros((min(len_memory, batch_size), env_dim))
targets = np.zeros((inputs.shape[0], num_actions))
for i, idx in enumerate(np.random.randint(0, len_memory,
size=inputs.shape[0])):
state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]
game_over = self.memory[idx][1]
inputs[i:i+1] = state_t
# There should be no target values for actions not taken.
# Thou shalt not correct actions not taken #deep
targets[i] = model.predict(state_t)[0]
Q_sa = np.max(model.predict(state_tp1)[0])
if game_over: # if game_over is True
targets[i, action_t] = reward_t
else:
# reward_t + gamma * max_a' Q(s', a')
targets[i, action_t] = reward_t + self.discount * Q_sa
return inputs, targets
if __name__ == "__main__":
# parameters
epsilon = .1 # exploration
num_actions = 3 # [move_left, stay, move_right]
epoch = 1000
max_memory = 500
hidden_size = 100
batch_size = 50
grid_size = 10
model = Sequential()
model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
model.add(Dense(hidden_size, activation='relu'))
model.add(Dense(num_actions))
model.compile(sgd(lr=.2), "mse")
# If you want to continue training from a previous model, just uncomment the line bellow
# model.load_weights("model.h5")
# Define environment/game
env = Catch(grid_size)
# Initialize experience replay object
exp_replay = ExperienceReplay(max_memory=max_memory)
# Train
win_cnt = 0
for e in range(epoch):
loss = 0.
env.reset()
game_over = False
# get initial input
input_t = env.observe()
while not game_over:
input_tm1 = input_t
# get next action
if np.random.rand() <= epsilon:
action = np.random.randint(0, num_actions, size=1)
else:
q = model.predict(input_tm1)
action = np.argmax(q[0])
# apply action, get rewards and new state
input_t, reward, game_over = env.act(action)
if reward == 1:
win_cnt += 1
# store experience
exp_replay.remember([input_tm1, action, reward, input_t], game_over)
# adapt model
inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
loss += model.train_on_batch(inputs, targets)[0]
print("Epoch {:03d}/999 | Loss {:.4f} | Win count {}".format(e, loss, win_cnt))
# Save trained model weights and architecture, this will be used by the visualization code
model.save_weights("model.h5", overwrite=True)
with open("model.json", "w") as outfile:
json.dump(model.to_json(), outfile)
import json
import matplotlib.pyplot as plt
import numpy as np
from keras.models import model_from_json
from qlearn import Catch
if __name__ == "__main__":
# Make sure this grid size matches the value used fro training
grid_size = 10
with open("model.json", "r") as jfile:
model = model_from_json(json.load(jfile))
model.load_weights("model.h5")
model.compile("sgd", "mse")
# Define environment, game
env = Catch(grid_size)
c = 0
for e in range(10):
loss = 0.
env.reset()
game_over = False
# get initial input
input_t = env.observe()
plt.imshow(input_t.reshape((grid_size,)*2),
interpolation='none', cmap='gray')
plt.savefig("%03d.png" % c)
c += 1
while not game_over:
input_tm1 = input_t
# get next action
q = model.predict(input_tm1)
action = np.argmax(q[0])
# apply action, get rewards and new state
input_t, reward, game_over = env.act(action)
plt.imshow(input_t.reshape((grid_size,)*2),
interpolation='none', cmap='gray')
plt.savefig("%03d.png" % c)
c += 1
@AimoreRRD
Copy link

Thanks for the comment @arr28

@Shreshtha-S
Copy link

I did not understand the code. Could you please explain it line by line? That will be of great help.

@YitzhakSp
Copy link

thanks, very nice code !

@upstarsong
Copy link

Hello, Eder

Thanks for the nice reinforcement example. I could study about reinforcement learning efficiently.

I'm not good at English, but I hope it's understable to you.

By the way, I have an Idea for more good train. I think the basket should wait under the fruit before it get fall to the ground. So I changed definition of _get_reward() like this.

def _get_reward_2nd(self):
        fruit_row, fruit_col, basket = self.state[0]
        if abs(fruit_col - basket) <= 1:
            return 1
        else:
            return -1

After changing definition, the win_cnt soared like this,

Epoch 996/999 | Loss 0.0043 | Win count 7166
Epoch 997/999 | Loss 0.0051 | Win count 7173
Epoch 998/999 | Loss 0.0048 | Win count 7182
Epoch 999/999 | Loss 0.0058 | Win count 7186

I'm sorry if you have already know this. I hope my opinion would be your help.

I think there might be some misunderstanding of your code. You get more count because you will get reward when the fruit has not reached the last row but near to the basket in column. But it doesn't mean the code has got the better training. But I am a beginner too and I have many question of the code. So maybe I am wrong. I just hope it will help you to get better. Thank you.

@EderSantana
Copy link
Author

awesome! thanks for letting me know @upstarsong

@grant541
Copy link

grant541 commented Aug 27, 2022

I did not understand the code. Could you please explain it line by line? That will be of great help.

Agreed, I am looking at the code not understanding the majority of the functionality. I'm glad it has helped some users but I'm relatively new to the subject and without additional comments there little I could hope to learn from this example.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment