Amazon SageMaker Matsuri 20190212: ChainerRL x SageMaker
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"学習は、以下の通り実行します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import get_execution_role\n",
"import sagemaker\n",
"\n",
"sagemaker_session = sagemaker.Session()\n",
"\n",
"# This role retrieves the SageMaker-compatible role used by this Notebook Instance.\n",
"role = get_execution_role()\n",
"\n",
"from sagemaker.chainer.estimator import Chainer\n",
"\n",
"rl_image = '520713654638.dkr.ecr.{}.amazonaws.com/sagemaker-rl-mxnet:coach0.11.0-cpu-py3'.format(sagemaker_session.boto_region_name)\n",
"local_mode = True\n",
"chainer_estimator = Chainer(entry_point='train_dqn_gym.py',\n",
" source_dir=\".\",\n",
" role=role,\n",
" image_name=rl_image,\n",
" sagemaker_session=None if local_mode else sagemaker_session,\n",
" train_instance_count=1,\n",
" train_instance_type='local' if local_mode else 'ml.m4.4xlarge',\n",
" hyperparameters={'gpu': -1}) \n",
"\n",
"chainer_estimator.fit(wait=local_mode)"
]
},
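{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check after training, the S3 location of the trained model artifact (`model.tar.gz`) can be printed from the estimator's `model_data` attribute. This is a minimal sketch and assumes the `fit` call above has completed and written to the session's default bucket, which is also what the download cell below assumes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: show where SageMaker stored the trained model artifact.\n",
"# Assumes fit() above has completed; the download cell below uses the same location.\n",
"print(chainer_estimator.model_data)"
]
},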
{
"cell_type": "markdown",
"metadata": {},
"source": [
"学習が終わったら、以下の通りダウンロードします。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"path = chainer_estimator.model_data.replace('s3://' + sagemaker_session.default_bucket() + '/', '')\n",
"s3 = boto3.resource('s3')\n",
"s3.Bucket(sagemaker_session.default_bucket()).download_file(path, 'model.tar.gz')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"結果の確認は、動画で行います。SageMaker RL のコンテナが必要なので、Chainer Estimator の、`image_name`引数にECRのイメージ名を指定します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf model && mkdir model && cd model && tar xvzf ../model.tar.gz\n",
"\n",
"# This role retrieves the SageMaker-compatible role used by this Notebook Instance.\n",
"role = get_execution_role()\n",
"\n",
"from sagemaker.chainer.estimator import Chainer\n",
"\n",
"rl_image = '520713654638.dkr.ecr.{}.amazonaws.com/sagemaker-rl-mxnet:coach0.11.0-cpu-py3'.format(sagemaker_session.boto_region_name)\n",
"local_mode = True\n",
"chainer_estimator = Chainer(entry_point='train_dqn_gym.py',\n",
" source_dir=\".\",\n",
" role=role,\n",
" image_name=rl_image,\n",
" sagemaker_session=None if local_mode else sagemaker_session,\n",
" train_instance_count=1,\n",
" train_instance_type ='local' if local_mode else 'ml.m4.4xlarge',\n",
" hyperparameters={'gpu': -1, 'record-video': True, 'load': 'model/', 'steps': 100})\n",
"\n",
"chainer_estimator.fit(wait=local_mode)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"末尾に`WARNING:sagemaker.local.image:Failed to delete: /tmp/xxxx/algo-x-xxx` と表示されるので、それをカレントディレクトリにコピーしてください。\n",
"すると、左のフォルダ表示の以下に、mp4ファイルが格納されています。そのmp4ファイルをダウンロードしてください。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cp -r /tmp/xxxx/algo-x-xxx result1 # パスは書き換え"
]
},
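{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of downloading the mp4 files, they can also be previewed inline in the notebook. The following is a minimal sketch: it assumes the videos were copied into `result1/` by the cell above, so the exact path and file names may differ in your run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: embed the first mp4 found under result1/ in the notebook.\n",
"# Assumes the previous cell copied the recorded videos into result1/.\n",
"import base64\n",
"import glob\n",
"from IPython.display import HTML\n",
"\n",
"mp4_files = sorted(glob.glob('result1/*.mp4'))\n",
"video = base64.b64encode(open(mp4_files[0], 'rb').read()).decode('ascii')\n",
"HTML('<video controls src=\"data:video/mp4;base64,{}\"></video>'.format(video))"
]
},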
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_chainer_p36",
"language": "python",
"name": "conda_chainer_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
"""An example of training DQN against OpenAI Gym Envs.
This script is an example of training a DQN agent against OpenAI Gym envs.
Both discrete and continuous action spaces are supported. For continuous action
spaces, A NAF (Normalized Advantage Function) is used to approximate Q-values.
To solve CartPole-v0, run:
python train_dqn_gym.py --env CartPole-v0
To solve Pendulum-v0, run:
python train_dqn_gym.py --env Pendulum-v0
"""
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
import argparse
import os
import sys
def install(package):
if 'SM_OUTPUT_DATA_DIR' in os.environ:
os.system('pip3 install {}'.format(package))
install('chainerrl==0.4.0')
#for local
if not 'SM_OUTPUT_DATA_DIR' in os.environ:
os.environ['SM_OUTPUT_DATA_DIR'] = 'results'
os.environ['SM_MODEL_DIR'] = 'model'
os.environ['SM_CHANNEL_TRAIN'] = ''
os.environ['SM_CHANNEL_TEST'] = ''
if not 'SM_CHANNEL_TRAIN' in os.environ:
os.environ['SM_CHANNEL_TRAIN'] = ''
if not 'SM_CHANNEL_TEST' in os.environ:
os.environ['SM_CHANNEL_TEST'] = ''
import chainer
from chainer import optimizers
import chainerrl
from chainerrl.agents.dqn import DQN
from chainerrl import experiments
from chainerrl import explorers
from chainerrl import links
from chainerrl import misc
from chainerrl import q_functions
from chainerrl import replay_buffer
import gym
from gym import spaces
import gym.wrappers
import numpy as np
def main():
import logging
logging.basicConfig(level=logging.WARNING)
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0,
help='Random seed [0, 2 ** 32)')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--final-exploration-steps',
type=int, default=10 ** 4)
parser.add_argument('--start-epsilon', type=float, default=1.0)
parser.add_argument('--end-epsilon', type=float, default=0.1)
parser.add_argument('--noisy-net-sigma', type=float, default=None)
parser.add_argument('--demo', action='store_true', default=False)
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--steps', type=int, default=10 ** 5)
parser.add_argument('--eval-steps', type=int, default=1000)
parser.add_argument('--prioritized-replay', action='store_true')
parser.add_argument('--episodic-replay', action='store_true')
parser.add_argument('--replay-start-size', type=int, default=1000)
parser.add_argument('--target-update-interval', type=int, default=10 ** 2)
parser.add_argument('--target-update-method', type=str, default='hard')
parser.add_argument('--soft-update-tau', type=float, default=1e-2)
parser.add_argument('--update-interval', type=int, default=1)
parser.add_argument('--eval-n-runs', type=int, default=100)
parser.add_argument('--eval-interval', type=int, default=10 ** 4)
parser.add_argument('--n-hidden-channels', type=int, default=100)
parser.add_argument('--n-hidden-layers', type=int, default=2)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--minibatch-size', type=int, default=None)
parser.add_argument('--reward-scale-factor', type=float, default=1e-3)
parser.add_argument('--render-train', action='store_true')
parser.add_argument('--render-eval', action='store_true')
parser.add_argument('--monitor', action='store_true')
parser.add_argument('--record-video', type=str, default=None) # for sagemaker
# Required for sagemaker
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
args = parser.parse_args()
# Sagemaker's hyperparameter cannot set switch flags.
if args.record_video == "True":
print("video recording mode.")
args.demo = True
args.monitor = True
# Set a random seed used in ChainerRL
misc.set_random_seed(args.seed, gpus=(args.gpu,))
args.output_data_dir = experiments.prepare_output_dir(
args, args.output_data_dir, argv=sys.argv)
print('Output files are saved in {}'.format(args.output_data_dir))
def clip_action_filter(a):
return np.clip(a, action_space.low, action_space.high)
def make_env(test):
env = gym.make(args.env)
# Use different random seeds for train and test envs
env_seed = 2 ** 32 - 1 - args.seed if test else args.seed
env.seed(env_seed)
if args.monitor:
env = gym.wrappers.Monitor(env, args.output_data_dir, force=True)
if isinstance(env.action_space, spaces.Box):
misc.env_modifiers.make_action_filtered(env, clip_action_filter)
if not test:
misc.env_modifiers.make_reward_filtered(
env, lambda x: x * args.reward_scale_factor)
if ((args.render_eval and test) or
(args.render_train and not test)):
misc.env_modifiers.make_rendered(env)
return env
env = make_env(test=False)
timestep_limit = env.spec.tags.get(
'wrapper_config.TimeLimit.max_episode_steps')
obs_space = env.observation_space
obs_size = obs_space.low.size
action_space = env.action_space
if isinstance(action_space, spaces.Box):
action_size = action_space.low.size
# Use NAF to apply DQN to continuous action spaces
q_func = q_functions.FCQuadraticStateQFunction(
obs_size, action_size,
n_hidden_channels=args.n_hidden_channels,
n_hidden_layers=args.n_hidden_layers,
action_space=action_space)
# Use the Ornstein-Uhlenbeck process for exploration
ou_sigma = (action_space.high - action_space.low) * 0.2
explorer = explorers.AdditiveOU(sigma=ou_sigma)
else:
n_actions = action_space.n
q_func = q_functions.FCStateQFunctionWithDiscreteAction(
obs_size, n_actions,
n_hidden_channels=args.n_hidden_channels,
n_hidden_layers=args.n_hidden_layers)
# Use epsilon-greedy for exploration
explorer = explorers.LinearDecayEpsilonGreedy(
args.start_epsilon, args.end_epsilon, args.final_exploration_steps,
action_space.sample)
if args.noisy_net_sigma is not None:
links.to_factorized_noisy(q_func)
# Turn off explorer
explorer = explorers.Greedy()
# Draw the computational graph and save it in the output directory.
chainerrl.misc.draw_computational_graph(
[q_func(np.zeros_like(obs_space.low, dtype=np.float32)[None])],
os.path.join(args.output_data_dir, 'model'))
opt = optimizers.Adam()
opt.setup(q_func)
rbuf_capacity = 5 * 10 ** 5
if args.episodic_replay:
if args.minibatch_size is None:
args.minibatch_size = 4
if args.prioritized_replay:
betasteps = (args.steps - args.replay_start_size) \
// args.update_interval
rbuf = replay_buffer.PrioritizedEpisodicReplayBuffer(
rbuf_capacity, betasteps=betasteps)
else:
rbuf = replay_buffer.EpisodicReplayBuffer(rbuf_capacity)
else:
if args.minibatch_size is None:
args.minibatch_size = 32
if args.prioritized_replay:
betasteps = (args.steps - args.replay_start_size) \
// args.update_interval
rbuf = replay_buffer.PrioritizedReplayBuffer(
rbuf_capacity, betasteps=betasteps)
else:
rbuf = replay_buffer.ReplayBuffer(rbuf_capacity)
def phi(obs):
return obs.astype(np.float32)
agent = DQN(q_func, opt, rbuf, gpu=args.gpu, gamma=args.gamma,
explorer=explorer, replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval,
update_interval=args.update_interval,
phi=phi, minibatch_size=args.minibatch_size,
target_update_method=args.target_update_method,
soft_update_tau=args.soft_update_tau,
episodic_update=args.episodic_replay, episodic_update_len=16)
if args.load:
agent.load(args.load)
eval_env = make_env(test=True)
def start_display():
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1024, 768))
display.start()
import os
os.environ["DISPLAY"] = ":" + str(display.display) + "." + str(display.screen)
if args.demo:
print('Starting demo mode.')
#start_display()
eval_stats = experiments.eval_performance(
env=eval_env,
agent=agent,
n_runs=args.eval_n_runs,
max_episode_len=timestep_limit)
print('n_runs: {} mean: {} median: {} stdev {}'.format(
args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
eval_stats['stdev']))
else:
experiments.train_agent_with_evaluation(
agent=agent, env=env, steps=args.steps,
eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
outdir=args.output_data_dir, eval_env=eval_env,
max_episode_len=timestep_limit)
agent.save(args.model_dir)
if args.monitor:
eval_env.env.close()
eval_env.close()
env.close()
if __name__ == '__main__':
main()