@tambetm
Last active May 30, 2016 08:48

Uses the dueling network architecture with Q-learning, as outlined in this paper:

Dueling Network Architectures for Deep Reinforcement Learning
Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas
http://arxiv.org/abs/1511.06581

Refer to code for hyperparameter values.
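
For intuition, here is a minimal NumPy sketch (not part of the original script; the numbers are made up) of how the three --advantage modes turn the network output into Q-values: the first output unit is treated as the state value V(s) and the remaining units as advantages A(s, a), mirroring the Keras Lambda layers in the script below.

import numpy as np

out = np.array([1.0, 0.5, -0.5])   # hypothetical network output for one state: [V, A_1, A_2]
V, A = out[0], out[1:]

q_naive = V + A              # 'naive': Q(s, a) = V(s) + A(s, a)
q_max   = V + A - A.max()    # 'max':   Q(s, a) = V(s) + A(s, a) - max over a' of A(s, a')
q_avg   = V + A - A.mean()   # 'avg':   Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')

print(q_naive, q_max, q_avg)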

import argparse
import gym
from gym.spaces import Box, Discrete
from keras.models import Model
from keras.layers import Input, Dense, Lambda
from keras import backend as K
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--hidden_size', type=int, default=100)
parser.add_argument('--min_train', type=int, default=10)
parser.add_argument('--train_repeat', type=int, default=10)
parser.add_argument('--gamma', type=float, default=1)
parser.add_argument('--episodes', type=int, default=200)
parser.add_argument('--max_timesteps', type=int, default=200)
parser.add_argument('--activation', choices=['tanh', 'relu'], default='tanh')
parser.add_argument('--optimizer', choices=['adam', 'rmsprop'], default='adam')
parser.add_argument('--exploration', type=float, default=0.1)
parser.add_argument('--advantage', choices=['naive', 'max', 'avg'], default='naive')
parser.add_argument('--gym_monitor')
parser.add_argument('environment')
args = parser.parse_args()
env = gym.make(args.environment)
assert isinstance(env.observation_space, Box)
assert isinstance(env.action_space, Discrete)
# optionally record results with the Gym monitor
if args.gym_monitor:
    env.monitor.start(args.gym_monitor)

# one hidden layer; the output layer has n_actions + 1 units:
# the first unit is the state value V(s), the rest are advantages A(s, a)
x = Input(shape=env.observation_space.shape)
h = Dense(args.hidden_size, activation=args.activation)(x)
y = Dense(env.action_space.n + 1)(h)

# combine V(s) and A(s, a) into Q(s, a), depending on the chosen aggregation
if args.advantage == 'avg':
    # Q = V + A - mean(A)
    z = Lambda(lambda a: K.expand_dims(a[:, 0], dim=-1) + a[:, 1:] - K.mean(a[:, 1:], keepdims=True), output_shape=(env.action_space.n,))(y)
elif args.advantage == 'max':
    # Q = V + A - max(A)
    z = Lambda(lambda a: K.expand_dims(a[:, 0], dim=-1) + a[:, 1:] - K.max(a[:, 1:], keepdims=True), output_shape=(env.action_space.n,))(y)
elif args.advantage == 'naive':
    # Q = V + A
    z = Lambda(lambda a: K.expand_dims(a[:, 0], dim=-1) + a[:, 1:], output_shape=(env.action_space.n,))(y)
else:
    assert False, "unknown advantage aggregation"

model = Model(input=x, output=z)
model.summary()
model.compile(optimizer=args.optimizer, loss='mse')

# replay memory, stored as plain Python lists (never truncated)
prestates = []
actions = []
rewards = []
poststates = []
terminals = []
total_reward = 0
for i_episode in xrange(args.episodes):
    observation = env.reset()
    episode_reward = 0
    for t in xrange(args.max_timesteps):
        env.render()

        # epsilon-greedy action selection
        if np.random.random() < args.exploration:
            action = env.action_space.sample()
        else:
            s = np.array([observation])
            q = model.predict(s, batch_size=1)
            #print "q:", q
            action = np.argmax(q[0])
        #print "action:", action

        prestates.append(observation)
        actions.append(action)

        observation, reward, done, info = env.step(action)
        episode_reward += reward
        #print "reward:", reward

        rewards.append(reward)
        poststates.append(observation)
        terminals.append(done)

        # once enough transitions are collected, train on random minibatches
        if len(prestates) > args.min_train:
            for k in xrange(args.train_repeat):
                if len(prestates) > args.batch_size:
                    indexes = np.random.choice(len(prestates), size=args.batch_size)
                else:
                    indexes = range(len(prestates))

                # predicted Q-values serve as targets, except for the taken actions,
                # which get the standard Q-learning target
                qpre = model.predict(np.array(prestates)[indexes])
                qpost = model.predict(np.array(poststates)[indexes])
                for i in xrange(len(indexes)):
                    if terminals[indexes[i]]:
                        qpre[i, actions[indexes[i]]] = rewards[indexes[i]]
                    else:
                        qpre[i, actions[indexes[i]]] = rewards[indexes[i]] + args.gamma * np.amax(qpost[i])
                model.train_on_batch(np.array(prestates)[indexes], qpre)

        if done:
            break

    print "Episode {} finished after {} timesteps, episode reward {}".format(i_episode + 1, t + 1, episode_reward)
    total_reward += episode_reward

print "Average reward per episode {}".format(total_reward / args.episodes)

if args.gym_monitor:
    env.monitor.close()