A PPO train/eval method for TF-Agents: a PPOLearner class plus a train_eval loop that collects experience with Reverb and trains a PPOClipAgent. (Gist by @chokosabe, last active September 16, 2020.)
"""PPO Learner implementation."""
import gin
import tensorflow.compat.v2 as tf
from tf_agents.experimental.train import learner
from tf_agents.networks import utils
from tf_agents.utils import common


@gin.configurable
class PPOLearner(object):
  """Manages all the learning details needed when training a PPO agent.

  These include:
  * Using distribution strategies correctly
  * Summaries
  * Checkpoints
  * Minimizing entering/exiting TF context:
    Especially in the case of TPUs scheduling a single TPU program to
    perform multiple train steps is critical for performance.
  * Generalizes the train call to be done correctly across CPU, GPU, or TPU
    executions managed by DistributionStrategies. This uses `strategy.run` and
    then makes sure to do a reduce operation over the `LossInfo` returned by
    the agent.
  """

  def __init__(self,
               root_dir,
               train_step,
               agent,
               max_num_sequences=None,
               minibatch_size=None,
               shuffle_buffer_size=None,
               after_train_strategy_step_fn=None,
               triggers=None,
               checkpoint_interval=100000,
               summary_interval=1000,
               use_kwargs_in_agent_train=False,
               strategy=None):
"""Initializes a PPOLearner instance.
Args:
root_dir: Main directory path where checkpoints, saved_models, and
summaries will be written to.
train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
number of train steps. This is used for artifacts created like
summaries, or outputs in the root_dir.
agent: `tf_agent.TFAgent` instance to train with.
max_num_sequences: The max number of sequences to read from the input
dataset in `run`. Defaults to None, in which case `run` will terminate
when reach the end of the dataset (for instance when the rate limiter
times out).
minibatch_size: The minibatch size. The dataset used for training is
shaped [minibatch_size, 1, ...].
shuffle_buffer_size: The buffer size for shuffling the trajectories before
splitting them into mini batches. Only required when mini batch
learning is enabled (minibatch_size is set). Otherwise it is ignored.
Commonly set to a number 1-3x the episode length of your environment.
after_train_strategy_step_fn: (Optional) callable of the form
`fn(sample, loss)` which can be used for example to update priorities in
a replay buffer where sample is pulled from the `experience_iterator`
and loss is a `LossInfo` named tuple returned from the agent. This is
called after every train step. It runs using `strategy.run(...)`.
triggers: List of callables of the form `trigger(train_step)`. After every
`run` call every trigger is called with the current `train_step` value
as an np scalar.
checkpoint_interval: Number of train steps in between checkpoints. Note
these are placed into triggers and so a check to generate a checkpoint
only occurs after every `run` call. Set to -1 to disable. This only
takes care of the checkpointing the training process. Policies must be
explicitly exported through triggers
summary_interval: Number of train steps in between summaries. Note these
are placed into triggers and so a check to generate a checkpoint only
occurs after every `run` call.
use_kwargs_in_agent_train: If True the experience from the replay buffer
is passed into the agent as kwargs. This requires samples from the RB to
be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This
is useful if you have an agent with a custom argspec.
strategy: (Optional) `tf.distribute.Strategy` to use during training.
"""
    if minibatch_size is not None and shuffle_buffer_size is None:
      raise ValueError(
          'shuffle_buffer_size must be provided if minibatch_size is not '
          'None.')

    if agent.update_normalizers_in_train:
      raise ValueError(
          'agent.update_normalizers_in_train should be set to False when '
          'PPOLearner is used.')

    self._agent = agent
    self._max_num_sequences = max_num_sequences
    self._minibatch_size = minibatch_size
    self._shuffle_buffer_size = shuffle_buffer_size
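
    # The dataset pipeline (caching, epoch repeats, shuffling, minibatching)
    # is driven by `run` below, so the generic Learner is created with
    # `experience_dataset_fn=None`; it is used here for its strategy scope,
    # summary writer, checkpointing and triggers.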
    self._generic_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=None,
        after_train_strategy_step_fn=after_train_strategy_step_fn,
        triggers=triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        use_kwargs_in_agent_train=use_kwargs_in_agent_train,
        strategy=strategy)

  def run(self, iterations, dataset):
    """Runs training until the dataset times out or num sequences is reached.

    Args:
      iterations: Number of iterations/epochs to repeat over the collected
        sequences. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
        Roboschool and 3 for Atari.
      dataset: A `tf.data.Dataset` where each sample is shaped
        [sample_batch_size, sequence_length, ...], commonly the output from
        `reverb_replay_buffer.as_dataset(sample_batch_size, preprocess_fn)`.

    Returns:
      The total loss computed before running the final step.
    """
    # TODO(b/160802425): Verify this setup works with distributed.
    if self._max_num_sequences:
      dataset = dataset.take(self._max_num_sequences)

    cached_dataset = dataset.cache()
    self._update_advantage_normalizer(cached_dataset)

    new_dataset = cached_dataset.repeat(iterations)

    if self._minibatch_size:

      def squash_dataset_element(sequence, info):
        return tf.nest.map_structure(
            utils.BatchSquash(2).flatten, (sequence, info))

      # We unbatch the dataset shaped [B, T, ...] to a new dataset that
      # contains individual elements.
      # Note that we unbatch across the time dimension, which could result in
      # mini batches that contain subsets from more than one sequence. The PPO
      # agent can handle mini batches across episode boundaries.
      new_dataset = new_dataset.map(squash_dataset_element).unbatch()
      new_dataset = new_dataset.shuffle(self._shuffle_buffer_size)
      new_dataset = new_dataset.batch(1, drop_remainder=True)
      new_dataset = new_dataset.batch(
          self._minibatch_size, drop_remainder=True)
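      # The two-stage batching above yields elements shaped
      # [minibatch_size, 1, ...]: `batch(1)` reintroduces a length-one time
      # dimension for each squashed element, and the outer batch groups them
      # into minibatches, matching the shape documented in `__init__`.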

    # TODO(b/161133726): use learner.run once it supports None iterations.
    def _summary_record_if():
      return tf.math.equal(
          self._generic_learner.train_step %
          tf.constant(self._generic_learner.summary_interval), 0)

    with self._generic_learner.train_summary_writer.as_default(), \
         common.soft_device_placement(), \
         tf.compat.v2.summary.record_if(_summary_record_if), \
         self._generic_learner.strategy.scope():
      loss_info = self.multi_train_step(iter(new_dataset))

      train_step_val = self._generic_learner.train_step_numpy
      for trigger in self._generic_learner.triggers:
        trigger(train_step_val)

      self._update_normalizers(cached_dataset)

    return loss_info

  @common.function(autograph=True)
  def multi_train_step(self, iterator):
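    # The first batch is pulled with `next` before the loop so that
    # `loss_info` is defined ahead of the autograph-converted `for` loop;
    # each subsequent step overwrites it, and the final value is returned.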
    experience, sample_info = next(iterator)
    loss_info = self.single_train_step(experience, sample_info)
    for experience, sample_info in iterator:
      loss_info = self.single_train_step(experience, sample_info)
    return loss_info

  @common.function(autograph=False)
  def single_train_step(self, experience, sample_info):
    """Train a single (mini) batch of Trajectories."""
    if self._generic_learner.use_kwargs_in_agent_train:
      loss_info = self._generic_learner.strategy.run(
          self._agent.train, kwargs=experience)
    else:
      loss_info = self._generic_learner.strategy.run(
          self._agent.train, args=(experience,))

    if self._generic_learner.after_train_strategy_step_fn:
      if self._generic_learner.use_kwargs_in_agent_train:
        self._generic_learner.strategy.run(
            self._generic_learner.after_train_strategy_step_fn,
            kwargs=dict(
                experience=(experience, sample_info), loss_info=loss_info))
      else:
        self._generic_learner.strategy.run(
            self._generic_learner.after_train_strategy_step_fn,
            args=((experience, sample_info), loss_info))
    return loss_info

  @common.function(autograph=True)
  def _update_normalizers(self, dataset):
    iterator = iter(dataset)
    traj, _ = next(iterator)
    self._agent.update_observation_normalizer(traj.observation)
    self._agent.update_reward_normalizer(traj.reward)
    for traj, _ in iterator:
      self._agent.update_observation_normalizer(traj.observation)
      self._agent.update_reward_normalizer(traj.reward)

  @common.function(autograph=True)
  def _update_advantage_normalizer(self, dataset):
    self._agent._reset_advantage_normalizer()  # pylint: disable=protected-access
    iterator = iter(dataset)
    traj, _ = next(iterator)
    self._agent._update_advantage_normalizer(
        traj.policy_info['advantage'])  # pylint: disable=protected-access
    for traj, _ in iterator:
      self._agent._update_advantage_normalizer(
          traj.policy_info['advantage'])  # pylint: disable=protected-access

  @property
  def train_step_numpy(self):
    """The current train_step.

    Returns:
      The current `train_step`. Note this will return a scalar numpy array
      which holds the `train_step` value when this was called.
    """
    return self._generic_learner.train_step_numpy
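

# Illustrative call pattern (a sketch, not part of the original class; the
# full working example is the train_eval script below). `reverb_replay` and
# the hyperparameter values are placeholders:
#
#   ppo_learner = PPOLearner(root_dir, train_step, agent,
#                            minibatch_size=64,
#                            shuffle_buffer_size=2048)
#   dataset = reverb_replay.as_dataset(
#       sample_batch_size=1,
#       sequence_preprocess_fn=agent.preprocess_sequence)
#   loss_info = ppo_learner.run(iterations=10, dataset=dataset)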
"""Train and Eval PPOClipAgent in the Mujoco environments.
All hyperparameters come from the PPO paper
https://arxiv.org/abs/1707.06347.pdf
"""
import os

from absl import logging
import gin
import reverb
import tensorflow.compat.v2 as tf

from tf_agents.agents.ppo import ppo_clip_agent
from tf_agents.environments import suite_mujoco
# import .ppo_learner
from tf_agents.experimental.train import actor
from tf_agents.experimental.train import learner
from tf_agents.experimental.train import triggers
from tf_agents.experimental.train.utils import spec_utils
from tf_agents.experimental.train.utils import train_utils
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.metrics import py_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import value_network
from tf_agents.policies import py_tf_eager_policy

actor_fc_layers = (64, 64)
value_fc_layers = (64, 64)


@gin.configurable
def train_eval(
    root_dir,
    env_name='Hedge',
    # Training params
    num_iterations=20000,
    actor_fc_layers=actor_fc_layers,
    value_fc_layers=value_fc_layers,
    learning_rate=1e-5,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=None,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    env=None,
):
"""Trains and evaluates PPO (Importance Ratio Clipping).
Args:
root_dir: Main directory path where checkpoints, saved_models, and summaries
will be written to.
env_name: Name for the Mujoco environment to load.
num_iterations: The number of iterations to perform collection and training.
actor_fc_layers: List of fully_connected parameters for the actor network,
where each item is the number of units in the layer.
value_fc_layers: : List of fully_connected parameters for the value network,
where each item is the number of units in the layer.
learning_rate: Learning rate used on the Adam optimizer.
collect_sequence_length: Number of steps to take in each collect run.
minibatch_size: Number of elements in each mini batch. If `None`, the entire
collected sequence will be treated as one batch.
num_epochs: Number of iterations to repeat over all collected data per data
collection step. (Schulman,2017) sets this to 10 for Mujoco, 15 for
Roboschool and 3 for Atari.
importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For
more detail, see explanation at the top of the doc.
lambda_value: Lambda parameter for TD-lambda computation.
discount_factor: Discount factor for return computation. Default to `0.99`
which is the value used for all environments from (Schulman, 2017).
entropy_regularization: Coefficient for entropy regularization loss term.
Default to `0.0` because no entropy bonus was used in (Schulman, 2017).
value_pred_loss_coef: Multiplier for value prediction loss to balance with
policy gradient loss. Default to `0.5`, which was used for all
environments in the OpenAI baseline implementation. This parameters is
irrelevant unless you are sharing part of actor_net and value_net. In that
case, you would want to tune this coeeficient, whose value depends on the
network architecture of your choice.
use_gae: If True (default False), uses generalized advantage estimation for
computing per-timestep advantage. Else, just subtracts value predictions
from empirical return.
use_td_lambda_return: If True (default False), uses td_lambda_return for
training value function; here: `td_lambda_return = gae_advantage +
value_predictions`. `use_gae` must be set to `True` as well to enable TD
-lambda returns. If `use_td_lambda_return` is set to True while
`use_gae` is False, the empirical return will be used and a warning will
be logged.
gradient_clipping: Norm length to clip gradients.
value_clipping: Difference between new and old value predictions are clipped
to this threshold. Value clipping could be helpful when training
very deep networks. Default: no clipping.
reverb_port: Port for reverb server, if None, use a randomly chosen unused
port.
replay_capacity: The maximum number of elements for the replay buffer. Items
will be wasted if this is smalled than collect_sequence_length.
policy_save_interval: How often, in train_steps, the policy will be saved.
summary_interval: How often to write data into Tensorboard.
eval_interval: How often to run evaluation, in train_steps.
eval_episodes: Number of episodes to evaluate over.
debug_summaries: Boolean for whether to gather debug summaries.
summarize_grads_and_vars: If true, gradient summaries will be written.
"""
  collect_env = env
  eval_env = env
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      activation_fn=tf.nn.tanh,
      kernel_initializer=tf.keras.initializers.Orthogonal())
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  train_step = train_utils.create_train_step()
  current_iteration = tf.Variable(0, dtype=tf.int64)

  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)
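  # `learning_rate_fn` is re-evaluated by the optimizer each time gradients
  # are applied; because `current_iteration` is incremented once per
  # collect/train iteration in the loop below, the learning rate decays
  # linearly from `learning_rate` toward 0 over `num_iterations`.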

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to
      # False after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the
      # learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_times_sampled=1)
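  # `max_times_sampled=1` means each collected sequence can be served to the
  # learner at most once before Reverb removes it, so training never re-reads
  # stale data and stays on-policy as PPO expects.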
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name=table_name,
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  # TODO(b/162244134): move to using the episodic observer after the
  # performance issue caused by the bug is resolved.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(  # pylint: disable=protected-access
      reverb_replay.py_client,
      table_name,
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length,
      # allow_multi_episode_sequences=True
  )
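  # With `stride_length` equal to `sequence_length`, the observer writes
  # non-overlapping sequences of exactly `collect_sequence_length` steps, so
  # each collect run below yields a single item in the Reverb table.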

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  agent_learner = PPOLearner(
      root_dir,
      train_step,
      agent,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)
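  # `shuffle_buffer_size` must be set whenever `minibatch_size` is given (see
  # the check in PPOLearner.__init__); using the collect sequence length means
  # the shuffle buffer can hold the entire collected sequence.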

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training.')
  dataset = reverb_replay.as_dataset(
      sample_batch_size=num_environments,
      sequence_preprocess_fn=agent.preprocess_sequence)
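  # `sequence_preprocess_fn=agent.preprocess_sequence` computes value
  # predictions and advantages while the dataset is read, which pairs with
  # `compute_value_and_advantage_in_train=False` on the agent above.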

  for _ in range(num_iterations):
    collect_actor.run()

    # TODO(b/159490625): Get rid of the reset call once the
    # multi_episode_sequences flag is gone.
    # TODO(b/159615593): Update to use observer.flush.
    # Reset the reverb observer to make sure the data collected is flushed
    # and written to the RB.
    rb_observer.reset()
    agent_learner.run(iterations=num_epochs, dataset=dataset)
    reverb_replay.clear()
    current_iteration.assign_add(1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
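

# A minimal, illustrative launcher (not part of the original gist). It assumes
# a Gym/Mujoco-style environment loaded via suite_mujoco; the environment name
# and root_dir below are placeholders.
if __name__ == '__main__':
  tf.compat.v1.enable_v2_behavior()
  demo_env = suite_mujoco.load('HalfCheetah-v2')  # Placeholder environment.
  train_eval(
      root_dir='/tmp/ppo_train_eval',  # Placeholder output directory.
      env=demo_env,
      num_iterations=10,
      eval_interval=0)  # Skip evaluation in this quick sketch.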