Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
import argparse
from datetime import datetime as dt
# Command-line interface:
#   --train       run training with periodic evaluation before the demo episode
#   --load PATH   call agent.load(PATH) before the demo episode
parser = argparse.ArgumentParser()
parser.add_argument('--train', action='store_true', default=False)
parser.add_argument('--load', default=None, type=str)
args = parser.parse_args()

# Gym environment shared by training and the rendered demo episode.
env = gym.make('Breakout-v0')
class QFunction(chainer.Chain):
    """Fully connected Q-network: observation -> 2 tanh hidden layers -> Q-values.

    Args:
        obs_size (int): Number of input features per observation. Because
            ``L.Linear`` folds every non-batch axis into one, this must be
            the product of all observation dimensions.
        n_actions (int): Number of discrete actions (size of the output layer).
        n_hidden_channels (int): Width of each of the two hidden layers.
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        # Register child links through init_scope(); passing links as keyword
        # arguments to Chain.__init__ is deprecated since Chainer v2.
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """Compute one Q-value per action for a batch of observations.

        Args:
            x (ndarray or chainer.Variable): Observations; axis 0 is the
                batch axis and the remaining axes are flattened by the
                first Linear layer.
            test (bool): Test-mode flag. Not used by this network; kept so
                callers that pass it keep working.

        Returns:
            chainerrl.action_value.DiscreteActionValue: The output-layer
            activations wrapped as discrete action values.
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))
# Input width of the Q-network. L.Linear flattens every non-batch axis of its
# input, so the width must be the product of ALL observation dimensions.
# shape[0] alone is wrong for multi-dimensional (image) observations such as
# Breakout-v0's, and makes the first forward pass fail with a size mismatch.
# For 1-D observation spaces this is identical to shape[0].
obs_size = int(np.prod(env.observation_space.shape))
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)

# Adam optimizer (eps=1e-2 as configured) attached to the Q-network.
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# Discount factor passed to the agent.
gamma = 0.95

# Constant epsilon-greedy exploration (epsilon=0.3); random actions are drawn
# from the environment's own action space.
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# Experience replay with room for up to 10**6 entries.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)


def phi(x):
    """Cast an observation to float32 for the network (no copy if already float32)."""
    return x.astype(np.float32, copy=False)


agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_frequency=1,
    target_update_frequency=100, phi=phi)
if args.train:
    # Train for a fixed step budget with periodic evaluation; results are
    # written to a fresh timestamped directory under result/.
    run_dir = 'result/' + dt.now().strftime("%Y%m%d%H%M%S")
    chainerrl.experiments.train_agent_with_evaluation(
        agent,
        env,
        steps=2000,
        max_episode_len=1000,
        eval_n_runs=10,
        eval_frequency=1000,
        outdir=run_dir,
    )
if args.load:
    # Restore previously saved agent state from the given path.
    agent.load(args.load)

# Play one rendered episode with the current agent (freshly trained, loaded,
# or untrained if neither flag was given).
try:
    obs = env.reset()
    done = False
    while not done:
        env.render()
        action = agent.act(obs)
        obs, _, done, _ = env.step(action)
finally:
    # Always release environment resources (e.g. the render viewer), even if
    # the rollout raises or is interrupted.
    env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment