@rlan · Created April 27, 2021
Merge callbacks for any number of DefaultCallbacks classes in RLlib
from typing import Dict, List, Optional, Type, TYPE_CHECKING

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import AgentID, PolicyID

if TYPE_CHECKING:
    from ray.rllib.evaluation import RolloutWorker


@PublicAPI
def merge_callbacks(
        callbacks: List[DefaultCallbacks]) -> Type[DefaultCallbacks]:
    """Merge any number of callback objects into a single callbacks class.

    Each hook is forwarded to every callback object, in the order given.

    Example:
        class FirstCallbacks(DefaultCallbacks):
            ...

        class SecondCallbacks(DefaultCallbacks):
            ...

        config = {
            "env": Env,
            "lr": 1e-4,
            "callbacks": merge_callbacks(
                [FirstCallbacks(), SecondCallbacks()]),
        }
    """

    class ListOfCallbacks(DefaultCallbacks):
        """Fans each RLlib callback hook out to every merged callback."""

        def __init__(self):
            super().__init__()
            # Callback objects captured from the enclosing merge_callbacks() call.
            self.callbacks = callbacks

        def on_episode_start(self,
                             *,
                             worker: "RolloutWorker",
                             base_env: BaseEnv,
                             policies: Dict[PolicyID, Policy],
                             episode: MultiAgentEpisode,
                             env_index: Optional[int] = None,
                             **kwargs) -> None:
            # Dispatch to each merged callback that implements this hook.
            for cb in self.callbacks:
                func = getattr(cb, "on_episode_start", None)
                if callable(func):
                    func(worker=worker,
                         base_env=base_env,
                         policies=policies,
                         episode=episode,
                         env_index=env_index,
                         **kwargs)

        def on_episode_step(self,
                            *,
                            worker: "RolloutWorker",
                            base_env: BaseEnv,
                            episode: MultiAgentEpisode,
                            env_index: Optional[int] = None,
                            **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_episode_step", None)
                if callable(func):
                    func(worker=worker,
                         base_env=base_env,
                         episode=episode,
                         env_index=env_index,
                         **kwargs)

        def on_episode_end(self,
                           *,
                           worker: "RolloutWorker",
                           base_env: BaseEnv,
                           policies: Dict[PolicyID, Policy],
                           episode: MultiAgentEpisode,
                           env_index: Optional[int] = None,
                           **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_episode_end", None)
                if callable(func):
                    func(worker=worker,
                         base_env=base_env,
                         policies=policies,
                         episode=episode,
                         env_index=env_index,
                         **kwargs)

        def on_postprocess_trajectory(
                self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
                agent_id: AgentID, policy_id: PolicyID,
                policies: Dict[PolicyID, Policy],
                postprocessed_batch: SampleBatch,
                original_batches: Dict[AgentID, SampleBatch],
                **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_postprocess_trajectory", None)
                if callable(func):
                    func(worker=worker,
                         episode=episode,
                         agent_id=agent_id,
                         policy_id=policy_id,
                         policies=policies,
                         postprocessed_batch=postprocessed_batch,
                         original_batches=original_batches,
                         **kwargs)

        def on_sample_end(self, *, worker: "RolloutWorker",
                          samples: SampleBatch, **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_sample_end", None)
                if callable(func):
                    func(worker=worker,
                         samples=samples,
                         **kwargs)

        def on_learn_on_batch(self, *, policy: Policy,
                              train_batch: SampleBatch, **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_learn_on_batch", None)
                if callable(func):
                    func(policy=policy,
                         train_batch=train_batch,
                         **kwargs)

        def on_train_result(self, *, trainer, result: dict,
                            **kwargs) -> None:
            for cb in self.callbacks:
                func = getattr(cb, "on_train_result", None)
                if callable(func):
                    func(trainer=trainer,
                         result=result,
                         **kwargs)

    return ListOfCallbacks
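
Usage note: below is a minimal, untested sketch of how the merged callbacks could be plugged into a training run, assuming ray[rllib] 1.x, the built-in PPO trainer, and the CartPole-v0 Gym environment. EpisodeLengthCallbacks and RewardLoggingCallbacks are hypothetical DefaultCallbacks subclasses used only for illustration; merge_callbacks is the function defined above.

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks

# Hypothetical example callbacks; any DefaultCallbacks subclasses work.
class EpisodeLengthCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        episode.custom_metrics["episode_length"] = episode.length

class RewardLoggingCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        episode.custom_metrics["episode_reward"] = episode.total_reward

ray.init()
tune.run(
    "PPO",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        "num_workers": 0,
        # merge_callbacks() (defined above) returns a DefaultCallbacks
        # subclass, which is what the "callbacks" config key expects;
        # both callback objects then receive every hook, in list order.
        "callbacks": merge_callbacks(
            [EpisodeLengthCallbacks(), RewardLoggingCallbacks()]),
    },
)

Because merge_callbacks captures already-constructed callback objects, stateful callbacks keep their state across hooks; RLlib itself only instantiates the returned wrapper class.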