"""Training multiple policies in RLlib, reporting rewards separately (Ray 2.0).

Originally shared as a GitHub gist by @ArturNiederfahrenhorst
(last active August 15, 2022).
"""
import numpy as np
from pettingzoo.sisl import waterworld_v3
import ray
from ray.tune import CLIReporter
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.tune.registry import register_env
from ray.rllib.algorithms.callbacks import DefaultCallbacks
def env_creator(args):
    """Build the PettingZoo Waterworld environment wrapped for RLlib.

    Args:
        args: Env config passed in by RLlib. It is ignored, so the environment
            is always ``waterworld_v3.env()`` with its default settings.

    Returns:
        A ``PettingZooEnv`` wrapping a fresh Waterworld environment.
    """
    return PettingZooEnv(waterworld_v3.env())


class MyCallbacks(DefaultCallbacks):
    """Mirror each policy's mean episode reward into ``custom_metrics``.

    RLlib puts per-policy mean rewards in ``result["policy_reward_mean"]``.
    Copying that mapping to ``result["custom_metrics"]["policy_reward_mean"]``
    yields the nested keys ``custom_metrics/policy_reward_mean/<policy_id>``
    that the progress reporter in ``main`` displays.
    """

    def on_train_result(self, *, result: dict, **kwargs) -> None:
        # Policies with no finished episode yet have no entry in the source
        # mapping, so they are simply absent from the copy as well.
        per_policy = result.get("policy_reward_mean") or {}
        custom_metrics = result.setdefault("custom_metrics", {})
        custom_metrics["policy_reward_mean"] = dict(per_policy)


def main() -> None:
    """Train one PPO policy per Waterworld agent and report each one's reward."""
    ray.init()

    register_env("waterworld", env_creator)

    # Instantiate the env once only to read its spaces and agent IDs, then
    # release it; the training workers create their own copies.
    dummy_env = env_creator({})
    obs_space = dummy_env.observation_space
    act_space = dummy_env.action_space
    agent_ids = list(dummy_env.env.agents)
    dummy_env.close()

    config = PPOConfig()
    config.multi_agent(
        # One independent policy per agent, keyed by the agent's own ID.
        policies={aid: (None, obs_space, act_space, {}) for aid in agent_ids},
        policy_mapping_fn=(lambda agent_id, episode, **kwargs: agent_id),
    )
    config.rollouts(num_rollout_workers=4)
    config.environment(env="waterworld")
    config.callbacks(MyCallbacks)

    metric_columns = {
        "training_iteration": "training_iteration",
        "time_total_s": "time_total_s",
        "timesteps_total": "timesteps",
        "episodes_this_iter": "episodes_trained",
    }
    # One reward column per policy, derived from the same agent list used to
    # build the policies (instead of a hard-coded subset), e.g.
    # "custom_metrics/policy_reward_mean/pursuer_0" -> "m_reward_p_0".
    for aid in agent_ids:
        short = aid.rsplit("_", 1)[-1]
        metric_columns[f"custom_metrics/policy_reward_mean/{aid}"] = (
            f"m_reward_p_{short}"
        )
    metric_columns["episode_reward_mean"] = "mean_reward_sum"

    tune.Tuner(
        "PPO",
        run_config=air.RunConfig(
            # Demo-sized run: stop once at least one episode has completed.
            stop={"episodes_total": 1},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1000,
            ),
            progress_reporter=CLIReporter(
                metric_columns=metric_columns,
                # sort_by_metric needs a metric and mode to sort by.
                metric="episode_reward_mean",
                mode="max",
                sort_by_metric=True,
            ),
        ),
        param_space=config.to_dict(),
    ).fit()


if __name__ == "__main__":
    main()
# Comments and discussion for this script live on the original GitHub gist page.