@mitmul
Last active October 11, 2017 13:42
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
import chainerrl
import gym


class QFunction(chainer.Chain):
    """Fully connected Q-function with two hidden ReLU layers."""

    def __init__(self, n_in, n_out, n_hidden=100):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(n_in, n_hidden)
            self.l2 = L.Linear(n_hidden, n_hidden)
            self.l3 = L.Linear(n_hidden, n_out)

    def __call__(self, x):
        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        # Wrap the Q-values so the DQN agent can select discrete actions.
        return chainerrl.action_value.DiscreteActionValue(self.l3(h))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--start_epsilon', type=float, default=1.0)
    parser.add_argument('--end_epsilon', type=float, default=0.1)
    parser.add_argument('--replay_start', type=int, default=1000)
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--update_interval', type=int, default=1)
    parser.add_argument('--target_update_interval', type=int, default=10000)
    parser.add_argument('--clip_delta', type=int, default=1)
    parser.add_argument('--render', action='store_true', default=False)
    parser.add_argument('--n_episodes', type=int, default=1500)
    parser.add_argument('--max_episode_len', type=int, default=500)
    args = parser.parse_args()

    # Build the environment, Q-function, and optimizer.
    env = gym.make('CartPole-v0')
    obs = env.reset()
    n_in = int(np.prod(env.observation_space.shape))
    n_actions = env.action_space.n
    q_func = QFunction(n_in, n_actions)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(q_func)

    # Epsilon decays linearly over the first 10% of the expected total steps.
    explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
        args.start_epsilon, args.end_epsilon,
        args.n_episodes * args.max_episode_len * 0.1, env.action_space.sample)
    replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)
    agent = chainerrl.agents.DQN(
        q_func, optimizer, replay_buffer, args.gamma, explorer, args.gpu,
        args.replay_start, args.batchsize, args.update_interval,
        args.target_update_interval, bool(args.clip_delta),
        # gym returns float64 observations; Chainer expects float32.
        phi=lambda x: x.astype(np.float32, copy=False)
    )

    # Training loop: one act_and_train call per environment step.
    for i in range(1, args.n_episodes + 1):
        obs = env.reset()
        reward = 0
        done = False
        R = 0  # undiscounted return of the current episode
        t = 0  # step count within the current episode
        while not done and t < args.max_episode_len:
            if args.render:
                env.render()
            action = agent.act_and_train(obs, reward)
            obs, reward, done, _ = env.step(action)
            R += reward
            t += 1
        print('episode: {}\t R:{}\t stats:{}'.format(
            i, R, agent.get_statistics()))
        agent.stop_episode_and_train(obs, reward, done)
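
A possible follow-up, not part of the original gist: once training finishes, the agent can be saved and run greedily using ChainerRL's standard act/stop_episode interface. The sketch below assumes it is appended inside the same __main__ block, so agent, env, and args are still in scope; the directory name dqn_cartpole_model is arbitrary.

    # --- Hypothetical follow-up, not in the original gist ---
    # Save the trained model and run a few greedy (no-exploration) episodes.
    agent.save('dqn_cartpole_model')
    for eval_episode in range(3):
        obs = env.reset()
        done = False
        R = 0
        t = 0
        while not done and t < args.max_episode_len:
            action = agent.act(obs)  # greedy action; no exploration, no training
            obs, reward, done, _ = env.step(action)
            R += reward
            t += 1
        agent.stop_episode()  # clear the agent's per-episode state
        print('evaluation episode: {}\t R:{}'.format(eval_episode, R))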