Training multiple policies in RLlib, reporting rewards separately @ Ray 2.0
from pettingzoo.sisl import waterworld_v3

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.tune import CLIReporter
from ray.tune.registry import register_env
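
# The CLIReporter below expects per-policy rewards under
# "custom_metrics/policy_reward_mean/<policy_id>", but the callback that
# fills those keys is not shown in this snippet. The class below is a
# minimal sketch (an assumption, not necessarily the original code): it
# copies RLlib's built-in per-policy mean rewards into custom_metrics on
# every train result so the reporter columns resolve.
class MyCallbacks(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result, **kwargs):
        # result["policy_reward_mean"] maps policy_id -> mean episode reward.
        for policy_id, reward in result["policy_reward_mean"].items():
            result["custom_metrics"][f"policy_reward_mean/{policy_id}"] = reward
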
if __name__ == "__main__":
    ray.init()

    def env_creator(args):
        return PettingZooEnv(waterworld_v3.env())

    # Instantiate one env up front to read the (shared) obs/action spaces.
    dummy_env = env_creator({})
    register_env("waterworld", env_creator)

    obs_space = dummy_env.observation_space
    act_space = dummy_env.action_space

    config = PPOConfig()
    # One policy per agent; the mapping fn assigns each agent its own policy.
    config.multi_agent(
        policies={
            pid: (None, obs_space, act_space, {})
            for pid in dummy_env.env.agents
        },
        policy_mapping_fn=(lambda agent_id, episode, **kwargs: agent_id),
    )
    config.rollouts(num_rollout_workers=4)
    config.environment(env="waterworld")
    config.callbacks(MyCallbacks)
    config = config.to_dict()
    tune.Tuner(
        "PPO",
        run_config=air.RunConfig(
            stop={"episodes_total": 1},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1000,
            ),
            # Show each pursuer's mean reward in its own column, next to the
            # summed mean reward over all agents.
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "training_iteration",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "timesteps",
                    "episodes_this_iter": "episodes_trained",
                    "custom_metrics/policy_reward_mean/pursuer_0": "m_reward_p_0",
                    "custom_metrics/policy_reward_mean/pursuer_1": "m_reward_p_1",
                    "custom_metrics/policy_reward_mean/pursuer_2": "m_reward_p_2",
                    "custom_metrics/policy_reward_mean/pursuer_3": "m_reward_p_3",
                    "episode_reward_mean": "mean_reward_sum",
                },
                sort_by_metric=True,
            ),
        ),
        param_space=config,
    ).fit()