Created
April 27, 2021 04:14
-
-
Save rlan/1b050a95536913f632742d72a84ae382 to your computer and use it in GitHub Desktop.
Merge callbacks for any number of DefaultCallbacks classes in RLlib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Optional, Type, TYPE_CHECKING

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import AgentID, PolicyID
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.rllib.evaluation import RolloutWorker
@PublicAPI
def merge_callbacks(callbacks: List) -> Type[DefaultCallbacks]:
    """Merge several callbacks objects into one ``DefaultCallbacks`` subclass.

    The returned class overrides the hooks listed below and forwards each
    invocation, with the same keyword arguments, to every entry of
    ``callbacks`` in the order given. An entry that has no callable
    attribute with the hook's name is skipped for that hook.

    Forwarded hooks: ``on_episode_start``, ``on_episode_step``,
    ``on_episode_end``, ``on_postprocess_trajectory``, ``on_sample_end``,
    ``on_learn_on_batch`` and ``on_train_result``.

    Example::

        class FirstCallbacks(DefaultCallbacks):
            ...  # etc

        class SecondCallbacks(DefaultCallbacks):
            ...  # etc

        config = {
            "env": Env,
            "lr": 1e-4,
            "callbacks": merge_callbacks(
                [FirstCallbacks(), SecondCallbacks()]),
        }

    Args:
        callbacks: The callbacks to merge. Entries are normally callback
            instances; an entry that is a class is instantiated with no
            arguments each time the merged class is instantiated.

    Returns:
        A ``DefaultCallbacks`` subclass (a class, not an instance) whose
        instances dispatch to the merged callbacks.
    """
    # Snapshot the input once so that later mutation of the caller's list
    # (or a one-shot iterable) cannot change or empty the dispatch order.
    merged = tuple(callbacks)

    class ListOfCallbacks(DefaultCallbacks):
        """Callbacks object that fans every hook out to ``self.callbacks``."""

        def __init__(self):
            super().__init__()
            # Instances are used as given (and therefore shared between
            # ListOfCallbacks instances created in the same process);
            # classes are instantiated here so their hooks are bound methods.
            self.callbacks = [
                cb() if isinstance(cb, type) else cb for cb in merged
            ]

        def _forward(self, hook_name: str, hook_kwargs: dict) -> None:
            """Call ``hook_name`` on each merged callback, in order.

            ``hook_kwargs`` is passed as a plain dict (not ``**``) so that no
            forwarded keyword can collide with this helper's own parameters.
            """
            for cb in self.callbacks:
                hook = getattr(cb, hook_name, None)
                if callable(hook):
                    hook(**hook_kwargs)

        def on_episode_start(self,
                             *,
                             worker: "RolloutWorker",
                             base_env: BaseEnv,
                             policies: Dict[PolicyID, Policy],
                             episode: MultiAgentEpisode,
                             env_index: Optional[int] = None,
                             **kwargs) -> None:
            """Forward ``on_episode_start`` to every merged callback."""
            self._forward("on_episode_start", dict(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_index,
                **kwargs))

        def on_episode_step(self,
                            *,
                            worker: "RolloutWorker",
                            base_env: BaseEnv,
                            episode: MultiAgentEpisode,
                            env_index: Optional[int] = None,
                            **kwargs) -> None:
            """Forward ``on_episode_step`` to every merged callback."""
            self._forward("on_episode_step", dict(
                worker=worker,
                base_env=base_env,
                episode=episode,
                env_index=env_index,
                **kwargs))

        def on_episode_end(self,
                           *,
                           worker: "RolloutWorker",
                           base_env: BaseEnv,
                           policies: Dict[PolicyID, Policy],
                           episode: MultiAgentEpisode,
                           env_index: Optional[int] = None,
                           **kwargs) -> None:
            """Forward ``on_episode_end`` to every merged callback."""
            self._forward("on_episode_end", dict(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_index,
                **kwargs))

        def on_postprocess_trajectory(
                self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
                agent_id: AgentID, policy_id: PolicyID,
                policies: Dict[PolicyID, Policy],
                postprocessed_batch: SampleBatch,
                original_batches: Dict[AgentID, SampleBatch],
                **kwargs) -> None:
            """Forward ``on_postprocess_trajectory`` to every merged callback."""
            self._forward("on_postprocess_trajectory", dict(
                worker=worker,
                episode=episode,
                agent_id=agent_id,
                policy_id=policy_id,
                policies=policies,
                postprocessed_batch=postprocessed_batch,
                original_batches=original_batches,
                **kwargs))

        def on_sample_end(self, *, worker: "RolloutWorker",
                          samples: SampleBatch, **kwargs) -> None:
            """Forward ``on_sample_end`` to every merged callback."""
            self._forward("on_sample_end", dict(
                worker=worker,
                samples=samples,
                **kwargs))

        def on_learn_on_batch(self, *, policy: Policy,
                              train_batch: SampleBatch, **kwargs) -> None:
            """Forward ``on_learn_on_batch`` to every merged callback."""
            self._forward("on_learn_on_batch", dict(
                policy=policy,
                train_batch=train_batch,
                **kwargs))

        def on_train_result(self, *, trainer, result: dict,
                            **kwargs) -> None:
            """Forward ``on_train_result`` to every merged callback."""
            self._forward("on_train_result", dict(
                trainer=trainer,
                result=result,
                **kwargs))

    return ListOfCallbacks
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment