pfrl_train_rainbow_mario.py
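"""Train a Rainbow agent on Super Mario Bros with PFRL.

This script closely follows PFRL's Atari Rainbow example, swapping the Atari
env for gym-super-mario-bros (wrapped with nes-py's JoypadSpace).

Example usage (assumes `pip install pfrl gym-super-mario-bros nes-py`; the
flag values below are illustrative):

    python pfrl_train_rainbow_mario.py --gpu 0 --outdir results
"""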
import argparse
import json
import logging
import os

import gym
import numpy as np
import torch

import pfrl
from pfrl import agents, experiments, explorers
from pfrl import nn as pnn
from pfrl import replay_buffers, utils
from pfrl.q_functions import DistributionalDuelingDQN
from pfrl.wrappers import atari_wrappers

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros  # noqa: F401  (registers SuperMarioBros-* envs with gym)
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="SuperMarioBros-v0")
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--demo", action="store_true", default=False)
    parser.add_argument("--load-pretrained", action="store_true", default=False)
    parser.add_argument(
        "--pretrained-type", type=str, default="best", choices=["best", "final"]
    )
    parser.add_argument("--load", type=str, default=None)
    parser.add_argument("--eval-epsilon", type=float, default=0.0)
    parser.add_argument("--noisy-net-sigma", type=float, default=0.5)
    parser.add_argument("--steps", type=int, default=5 * 10 ** 7)
    parser.add_argument(
        "--max-frames",
        type=int,
        default=30 * 60 * 60,  # 30 minutes at 60 fps
        help="Maximum number of frames for each episode.",
    )
parser.add_argument("--replay-start-size", type=int, default=2 * 10 ** 4) | |
parser.add_argument("--eval-n-steps", type=int, default=125000) | |
parser.add_argument("--eval-interval", type=int, default=250000) | |
parser.add_argument( | |
"--log-level", | |
type=int, | |
default=20, | |
help="Logging level. 10:DEBUG, 20:INFO etc.", | |
) | |
parser.add_argument( | |
"--render", | |
action="store_true", | |
default=False, | |
help="Render env states in a GUI window.", | |
) | |
parser.add_argument( | |
"--monitor", | |
action="store_true", | |
default=False, | |
help=( | |
"Monitor env. Videos and additional information are saved as output files." | |
), | |
) | |
parser.add_argument("--n-best-episodes", type=int, default=200) | |
args = parser.parse_args() | |

    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL.
    utils.set_random_seed(args.seed)

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print("Output files are saved in {}".format(args.outdir))

    def make_atari(env_id, max_frames=30 * 60 * 60):
        env = gym.make(env_id)
        assert isinstance(env, gym.wrappers.TimeLimit)
        # Unwrap the TimeLimit wrapper because we use our own time limit.
        env = env.env
        if max_frames:
            env = pfrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=max_frames)
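        # Standard DeepMind-style preprocessing: start each episode with a
        # random number of no-ops (up to 30) for initial-state variety, then
        # repeat each action for 4 frames, taking the pixelwise max over the
        # last two frames to remove sprite flicker.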
        env = pfrl.wrappers.atari_wrappers.NoopResetEnv(env, noop_max=30)
        env = pfrl.wrappers.atari_wrappers.MaxAndSkipEnv(env, skip=4)
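        # JoypadSpace restricts the NES controller to the 7 button
        # combinations in SIMPLE_MOVEMENT (NOOP, right, right+A, right+B,
        # right+A+B, A, left), giving a small discrete action space.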
        env = JoypadSpace(env, SIMPLE_MOVEMENT)
        return env

    def make_env(test):
        # List of Mario Bros envs: https://pypi.org/project/gym-super-mario-bros/
        assert "SuperMarioBros" in args.env
        # Use different random seeds for train and test envs.
        env_seed = test_seed if test else train_seed
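        # wrap_deepmind applies the usual DQN pixel preprocessing (84x84
        # grayscale frames with frame stacking) plus, during training only,
        # reward clipping to [-1, 1].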
        env = atari_wrappers.wrap_deepmind(
            make_atari(args.env, max_frames=args.max_frames),
            episode_life=False,  # SuperMarioBros is always episodic
            clip_rewards=not test,
        )
        env.seed(int(env_seed))
        if test:
            # Randomize actions like epsilon-greedy during evaluation as well.
            env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
        if args.monitor:
            env = pfrl.wrappers.Monitor(
                env, args.outdir, mode="evaluation" if test else "training"
            )
        if args.render:
            env = pfrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)
    n_actions = env.action_space.n
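
    # Distributional (C51) value head: the return distribution is represented
    # by 51 atoms on the support [-10, 10], the values used in the Rainbow
    # paper.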
    n_atoms = 51
    v_max = 10
    v_min = -10
    q_func = DistributionalDuelingDQN(
        n_actions,
        n_atoms,
        v_min,
        v_max,
    )

    # Noisy nets
    pnn.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
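    # Noisy layers provide state-dependent exploration on their own, so the
    # usual epsilon-greedy schedule is unnecessary.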
    # Turn off explorer
    explorer = explorers.Greedy()

    # Use the same hyperparameters as https://arxiv.org/abs/1710.02298
    opt = torch.optim.Adam(q_func.parameters(), 6.25e-5, eps=1.5 * 10 ** -4)

    # Prioritized replay with 3-step returns.
    # Anneal beta from beta0 to 1 throughout training.
    update_interval = 4
    betasteps = args.steps / update_interval
    rbuf = replay_buffers.PrioritizedReplayBuffer(
        10 ** 6,
        alpha=0.5,
        beta0=0.4,
        betasteps=betasteps,
        num_steps=3,
        normalize_by_max="memory",
    )
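
    # Observations arrive from the wrappers as stacked uint8 frames; storing
    # them as uint8 in the replay buffer saves memory, and phi rescales a
    # stack to float32 in [0, 1] only when a minibatch is built.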
    def phi(x):
        # Feature extractor: rescale uint8 pixels to float32 in [0, 1].
        return np.asarray(x, dtype=np.float32) / 255
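
    # CategoricalDoubleDQN combines the C51 categorical update with double
    # Q-learning action selection; together with the dueling architecture,
    # noisy nets, prioritized replay, and 3-step returns above, this is the
    # full Rainbow agent.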
    Agent = agents.CategoricalDoubleDQN
    agent = Agent(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=explorer,
        minibatch_size=32,
        replay_start_size=args.replay_start_size,
        target_update_interval=32000,
        update_interval=update_interval,
        batch_accumulator="mean",
        phi=phi,
    )

    if args.load or args.load_pretrained:
        # --load and --load-pretrained are mutually exclusive.
        assert not args.load or not args.load_pretrained
        if args.load:
            agent.load(args.load)
        else:
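            # download_model fetches pretrained weights from the PFRL model
            # zoo. Note that the zoo covers the Atari suite, so a pretrained
            # "Rainbow" model for SuperMarioBros-v0 is unlikely to be
            # available, and this branch will probably fail for this env.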
            agent.load(
                utils.download_model(
                    "Rainbow", args.env, model_type=args.pretrained_type
                )[0]
            )

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env, agent=agent, n_steps=args.eval_n_steps, n_episodes=None
        )
        print(
            "n_episodes: {} mean: {} median: {} stdev: {}".format(
                eval_stats["episodes"],
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=args.eval_n_steps,
            eval_n_episodes=None,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=True,
            eval_env=eval_env,
        )
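
        # After training, reload the best-scoring snapshot saved by
        # save_best_so_far_agent and evaluate it more thoroughly.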
        dir_of_best_network = os.path.join(args.outdir, "best")
        agent.load(dir_of_best_network)

        # Run 200 evaluation episodes, each capped at 30 minutes of play.
        # --max-frames counts emulator frames; with a frameskip of 4, the
        # corresponding agent-step limit is max_frames // 4.
        stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.n_best_episodes,
            max_episode_len=args.max_frames // 4,
            logger=None,
        )

        with open(os.path.join(args.outdir, "bestscores.json"), "w") as f:
            json.dump(stats, f)

        print("The results of the best scoring network:")
        for stat in stats:
            print("{}: {}".format(stat, stats[stat]))


if __name__ == "__main__":
    main()