@Chachay
Last active February 18, 2021 16:53
How to use PFRL and MLFLOW together
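The script below trains a Soft Actor-Critic (SAC) agent from PFRL on a Gym environment (Pendulum-v0 by default) and uses a custom PFRL EvaluationHook to push the mean evaluation return to MLflow, together with the run parameters and the PFRL output directory as artifacts.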
import argparse
from distutils.version import LooseVersion
import functools
import glob
import os
import gym
import numpy as np
import torch
from torch import nn
from torch import distributions
import pfrl
from pfrl.agents import SoftActorCritic
from pfrl import experiments
from pfrl.nn.lmbda import Lambda
from pfrl import replay_buffers
from pfrl import utils
from pfrl.experiments.evaluation_hooks import EvaluationHook
import mlflow
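
# PFRL invokes each registered EvaluationHook at the end of every evaluation
# phase; LogMLFlow simply forwards the mean evaluation return ("R_mean") to
# the active MLflow run so training progress can be followed in the MLflow UI.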
class LogMLFlow(EvaluationHook):
    def __init__(self):
        self.support_train_agent_batch = True
        self.support_train_agent = True

    def __call__(self, env, agent, evaluator, step, eval_stats, agent_stats, env_stats):
        # eval_stats holds the evaluation return statistics ("mean", "median", "stdev", ...)
        mlflow.log_metric("R_mean", eval_stats["mean"], step=step)
class NormalizeObsSpace(gym.ObservationWrapper):
    """Normalize a Box observation space to [-1, 1]^n."""

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box)
        self.observation_space = gym.spaces.Box(
            low=-np.ones_like(env.observation_space.low),
            high=np.ones_like(env.observation_space.low),
        )

    def observation(self, obs):
        n_obs = obs.copy()
        # -> [0, orig_high - orig_low]
        n_obs -= self.env.observation_space.low
        # -> [0, 2]
        n_obs /= (self.env.observation_space.high - self.env.observation_space.low) / 2
        # observation is now in [-1, 1]
        return n_obs - 1
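
# For example (a hypothetical check, not part of the original gist): Pendulum-v0
# observations lie in [-1, 1] x [-1, 1] x [-8, 8], so NormalizeObsSpace maps a
# raw observation of [0.0, 0.0, 8.0] to [0.0, 0.0, 1.0].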

def main():
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
    )
    parser.add_argument(
        "--env",
        type=str,
        default="Pendulum-v0",
        help="Gym Environment",
    )
    parser.add_argument(
        "--num-envs", type=int, default=1, help="Number of envs run in parallel."
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)")
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=10 ** 6,
        help="Total number of timesteps to train the agent.",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=10000,
        help="Interval in timesteps between evaluations.",
    )
    parser.add_argument(
        "--eval-n-runs",
        type=int,
        default=10,
        help="Number of episodes run for each evaluation.",
    )
    parser.add_argument(
        "--render", action="store_true", help="Render env states in a GUI window."
    )
    parser.add_argument(
        "--demo", action="store_true", help="Just run evaluation, not training."
    )
    parser.add_argument("--load-pretrained", action="store_true", default=False)
    parser.add_argument(
        "--load", type=str, default="", help="Directory to load agent from."
    )
    parser.add_argument(
        "--log-level", type=int, default=logging.INFO, help="Level of the root logger."
    )
    parser.add_argument(
        "--monitor", action="store_true", help="Wrap env with gym.wrappers.Monitor."
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=1000,
        help="Interval in timesteps between outputting log messages during training.",
    )
    parser.add_argument(
        "--update-interval",
        type=int,
        default=2048,
        help="Interval in timesteps between model updates.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of epochs to update model for per SAC iteration.",
    )
    parser.add_argument("--batch-size", type=int, default=64, help="Minibatch size")
    parser.add_argument(
        "--policy-output-scale",
        type=float,
        default=1.0,
        help="Weight initialization scale of policy output.",
    )
    parser.add_argument(
        "--replay-start-size",
        type=int,
        default=10000,
        help="Minimum replay buffer size before performing gradient updates.",
    )
    args = parser.parse_args()
    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL
    utils.set_random_seed(args.seed)

    # Set different random seeds for different subprocesses.
    # If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
    # If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
    process_seeds = np.arange(args.num_envs) + args.seed * args.num_envs
    assert process_seeds.max() < 2 ** 32

    args.outdir = experiments.prepare_output_dir(args, args.outdir)

    def make_env(process_idx, test):
        env = gym.make(args.env)
        # Use different random seeds for train and test envs
        process_seed = int(process_seeds[process_idx])
        env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = pfrl.wrappers.CastObservationToFloat32(env)
        env = NormalizeObsSpace(env)
        env = pfrl.wrappers.NormalizeActionSpace(env)
        if args.monitor:
            env = pfrl.wrappers.Monitor(env, args.outdir)
        if not test:
            # Scale rewards during training only; evaluation uses raw rewards
            env = pfrl.wrappers.ScaleReward(env, 1e-3)
        if args.render:
            env = pfrl.wrappers.Render(env)
        return env
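
    # make_batch_env builds args.num_envs copies of the environment in
    # subprocesses (MultiprocessVectorEnv) so the batch training loop can
    # step them in parallel.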
    def make_batch_env(test):
        return pfrl.envs.MultiprocessVectorEnv(
            [
                functools.partial(make_env, idx, test)
                for idx in range(args.num_envs)
            ]
        )

    # Only for getting timesteps, and obs-action spaces
    sample_env = gym.make(args.env)
    timestep_limit = sample_env.spec.max_episode_steps
    obs_space = sample_env.observation_space
    action_space = sample_env.action_space
    print("Observation space:", obs_space)
    print("Action space:", action_space)
    del sample_env

    assert isinstance(action_space, gym.spaces.Box)
    obs_size = obs_space.low.size
    action_size = action_space.low.size
    def squashed_diagonal_gaussian_head(x):
        assert x.shape[-1] == action_size * 2
        mean, log_scale = torch.chunk(x, 2, dim=1)
        log_scale = torch.clamp(log_scale, -20.0, 2.0)
        var = torch.exp(log_scale * 2)
        base_distribution = distributions.Independent(
            distributions.Normal(loc=mean, scale=torch.sqrt(var)), 1
        )
        # cache_size=1 is required for numerical stability
        return distributions.transformed_distribution.TransformedDistribution(
            base_distribution, [distributions.transforms.TanhTransform(cache_size=1)]
        )
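
    # The policy MLP outputs 2 * action_size values, which the head above splits
    # into the mean and log-scale of a diagonal Gaussian; TanhTransform squashes
    # sampled actions into [-1, 1], matching the normalized action space.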
    policy = nn.Sequential(
        nn.Linear(obs_size, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, action_size * 2),
        Lambda(squashed_diagonal_gaussian_head),
    )
    torch.nn.init.xavier_uniform_(policy[0].weight)
    torch.nn.init.xavier_uniform_(policy[2].weight)
    torch.nn.init.xavier_uniform_(policy[4].weight, gain=args.policy_output_scale)
    policy_optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

    def make_q_func_with_optimizer():
        q_func = nn.Sequential(
            pfrl.nn.ConcatObsAndAction(),
            nn.Linear(obs_size + action_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        torch.nn.init.xavier_uniform_(q_func[1].weight)
        torch.nn.init.xavier_uniform_(q_func[3].weight)
        torch.nn.init.xavier_uniform_(q_func[5].weight)
        q_func_optimizer = torch.optim.Adam(q_func.parameters(), lr=3e-4)
        return q_func, q_func_optimizer

    q_func1, q_func1_optimizer = make_q_func_with_optimizer()
    q_func2, q_func2_optimizer = make_q_func_with_optimizer()

    rbuf = replay_buffers.ReplayBuffer(10 ** 6)
    def burnin_action_func():
        """Select random actions until model is updated one or more times."""
        return np.random.uniform(action_space.low, action_space.high).astype(np.float32)
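
    # entropy_target=-action_size follows the usual SAC heuristic of using the
    # negative action dimensionality as the target entropy for automatic
    # temperature tuning; before the first gradient update, actions come from
    # burnin_action_func (uniform random over the action space).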
    # Hyperparameters in http://arxiv.org/abs/1802.09477
    agent = pfrl.agents.SoftActorCritic(
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        rbuf,
        gamma=0.99,
        replay_start_size=args.replay_start_size,
        gpu=args.gpu,
        minibatch_size=args.batch_size,
        burnin_action_func=burnin_action_func,
        entropy_target=-action_size,
        temperature_optimizer_lr=3e-4,
    )

    if args.load or args.load_pretrained:
        if args.load_pretrained:
            raise Exception("Pretrained models are currently unsupported.")
        # either load or load_pretrained must be false
        assert not args.load or not args.load_pretrained
        if args.load:
            agent.load(args.load)
        else:
            agent.load(utils.download_model("SAC", args.env, model_type="final")[0])
    if args.demo:
        env = make_env(0, True)
        eval_stats = experiments.eval_performance(
            env=env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit,
        )
        print(
            "n_runs: {} mean: {} median: {} stdev {}".format(
                args.eval_n_runs,
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        existing_exp = mlflow.get_experiment_by_name(args.env)
        if not existing_exp:
            mlflow.create_experiment(args.env)
        mlflow.set_experiment(args.env)

        log_mlflow = LogMLFlow()
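
        # One MLflow run per training session: parameters are logged up front,
        # LogMLFlow records R_mean after every evaluation, and the PFRL output
        # directory (scores, saved models) is uploaded as artifacts both right
        # after the run starts and again in the finally block.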
        try:
            mlflow.start_run()
            mlflow.log_param("Algo", "SAC")
            mlflow.log_artifacts(args.outdir)
            mlflow.log_param("OutDir", args.outdir)

            if args.num_envs == 1:
                experiments.train_agent_with_evaluation(
                    agent=agent,
                    env=make_env(0, False),
                    eval_env=make_env(0, True),
                    outdir=args.outdir,
                    steps=args.steps,
                    eval_n_steps=None,
                    eval_n_episodes=args.eval_n_runs,
                    eval_interval=args.eval_interval,
                    save_best_so_far_agent=True,
                    evaluation_hooks=(log_mlflow,),
                )
            else:
                experiments.train_agent_batch_with_evaluation(
                    agent=agent,
                    env=make_batch_env(False),
                    eval_env=make_batch_env(True),
                    outdir=args.outdir,
                    steps=args.steps,
                    eval_n_steps=None,
                    eval_n_episodes=args.eval_n_runs,
                    eval_interval=args.eval_interval,
                    log_interval=args.log_interval,
                    max_episode_len=timestep_limit,
                    save_best_so_far_agent=True,
                    evaluation_hooks=(log_mlflow,),
                )
        finally:
            mlflow.log_artifacts(args.outdir)
            mlflow.end_run()

if __name__ == "__main__":
    main()
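To try it locally, save the script (for example as train_sac_mlflow.py; the gist does not fix a filename) and run `python train_sac_mlflow.py --gpu -1` for a CPU-only run. With the default tracking setup, MLflow writes metrics and artifacts to a local `mlruns` directory, and `mlflow ui` starts a local web server for browsing the logged runs.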