# @Author: Astarag Mohapatra

import ray

assert ray.__version__ > "2.0.0", (
    "Please install ray 2.2.0 via 'pip install ray[rllib] ray[tune] lz4' "
    "(lz4 is required for population-based training)"
)

from pprint import pprint
from typing import Any, Dict, List, Optional, Union

from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.rllib.algorithms import Algorithm
from ray.tune import register_env
from ray.air import RunConfig, FailureConfig
from ray.tune.tune_config import TuneConfig
from ray.air.config import CheckpointConfig

import psutil

# Ray can overestimate the available memory (e.g. inside containers),
# so report the machine's total memory via psutil instead.
psutil_memory_in_bytes = psutil.virtual_memory().total
ray._private.utils.get_system_memory = lambda: psutil_memory_in_bytes
class DRLlibv2:
    """Thin wrapper around Ray Tune for tuning and training RLlib algorithms."""

    def __init__(
        self,
        trainable: Union[str, Any],
        train_env,
        train_env_name: str,
        params: dict,
        run_name: str = "tune_run",
        framework: str = "torch",
        local_dir: str = "tune_results",
        num_workers: int = 1,
        search_alg=None,
        concurrent_trials: int = 0,
        num_samples: int = 0,
        scheduler=None,
        log_level: str = "WARN",
        num_gpus: Union[float, int] = 0,
        num_cpus: Union[float, int] = 2,
        dataframe_save: str = "tune.csv",
        metric: str = "episode_reward_mean",
        mode: Union[str, List[str]] = "max",
        max_failures: int = 0,
        training_iterations: int = 100,
        checkpoint_num_to_keep: Optional[int] = None,
        checkpoint_freq: int = 0,
        reuse_actors: bool = False,
    ):
        register_env(train_env_name, lambda config: train_env)

        self.params = params
        self.params["framework"] = framework
        self.params["log_level"] = log_level
        self.params["num_gpus"] = num_gpus
        self.params["num_workers"] = num_workers
        self.params["env"] = train_env_name

        self.run_name = run_name
        self.local_dir = local_dir
        self.search_alg = search_alg
        if concurrent_trials != 0:
            self.search_alg = ConcurrencyLimiter(
                self.search_alg, max_concurrent=concurrent_trials
            )
        self.scheduler = scheduler
        self.num_samples = num_samples
        self.trainable = trainable
        if isinstance(self.trainable, str):
            # RLlib registers algorithms under upper-case names ("PPO", "A2C", ...)
            self.trainable = self.trainable.upper()
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.dataframe_save = dataframe_save
        self.metric = metric
        self.mode = mode
        self.max_failures = max_failures
        self.training_iterations = training_iterations
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_num_to_keep = checkpoint_num_to_keep
        self.reuse_actors = reuse_actors
    def train_tune_model(self):
        """
        Tune and train the model.
        Returns the Tune results object.
        """
        ray.init(
            num_cpus=self.num_cpus, num_gpus=self.num_gpus, ignore_reinit_error=True
        )
        tuner = tune.Tuner(
            self.trainable,
            param_space=self.params,
            tune_config=TuneConfig(
                search_alg=self.search_alg,
                scheduler=self.scheduler,
                num_samples=self.num_samples,
                metric=self.metric,
                mode=self.mode,
                reuse_actors=self.reuse_actors,
            ),
            run_config=RunConfig(
                name=self.run_name,
                local_dir=self.local_dir,
                failure_config=FailureConfig(
                    max_failures=self.max_failures, fail_fast=False
                ),
                stop={"training_iteration": self.training_iterations},
                checkpoint_config=CheckpointConfig(
                    num_to_keep=self.checkpoint_num_to_keep,
                    checkpoint_score_attribute=self.metric,
                    checkpoint_score_order=self.mode,
                    checkpoint_frequency=self.checkpoint_freq,
                    checkpoint_at_end=True,
                ),
                verbose=3,
            ),
        )
        self.results = tuner.fit()
        if self.search_alg is not None:
            # Persist the searcher state so tuning can be resumed later
            self.search_alg.save_to_dir(self.local_dir)
        # ray.shutdown()
        return self.results
    def infer_results(self, to_dataframe: Optional[str] = None, mode: str = "a"):
        """
        Return the tuning results as a dataframe (also saved to CSV)
        along with the best Result object.
        """
        results_df = self.results.get_dataframe()

        if to_dataframe is None:
            to_dataframe = self.dataframe_save
        results_df.to_csv(to_dataframe, mode=mode)

        best_result = self.results.get_best_result()
        # best_metric = best_result.metrics
        # best_checkpoint = best_result.checkpoint
        # best_trial_dir = best_result.log_dir

        return results_df, best_result
    def get_test_agent(self, test_env, test_env_name: str, checkpoint=None):
        """
        Restore an Algorithm from a checkpoint (the best one by default)
        for evaluation on the test environment.
        """
        register_env(test_env_name, lambda config: test_env)

        if checkpoint is None:
            checkpoint = self.results.get_best_result().checkpoint
        testing_agent = Algorithm.from_checkpoint(checkpoint)
        # testing_agent.config['env'] = test_env_name

        return testing_agent
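
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original gist): the lz4 extra mentioned
# in the assert above is needed for population-based training, which plugs
# into the `scheduler` argument. The mutation ranges, perturbation interval,
# and trial counts below are example values, not settings from the gist.
# ---------------------------------------------------------------------------
def pbt_runner_example(train_env, train_env_name: str) -> DRLlibv2:
    """Build a DRLlibv2 runner that tunes PPO with population-based training."""
    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,  # perturb hyperparameters every 5 iterations
        hyperparam_mutations={
            # PBT perturbs these keys of each trial's config in place
            "lr": tune.loguniform(1e-5, 1e-2),
            "train_batch_size": [2000, 4000],
        },
    )
    return DRLlibv2(
        trainable="PPO",
        train_env=train_env,
        train_env_name=train_env_name,
        # Fixed starting values; PBT mutates them at each perturbation interval
        params={"lr": 1e-4, "train_batch_size": 4000},
        scheduler=pbt,
        num_samples=4,
        training_iterations=20,
    )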
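
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): drives DRLlibv2 with RLlib's PPO on the
# classic-control "CartPole-v1" gym env and an Optuna searcher. The env, the
# hyperparameter ranges, and the trial counts are assumptions chosen for the
# example; substitute your own environment and search space.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gym  # ray 2.2's RLlib works against the gym API
    from ray.tune.search.optuna import OptunaSearch  # requires `pip install optuna`

    train_env = gym.make("CartPole-v1")

    # Hyperparameters to tune, expressed as a Tune search space over
    # standard RLlib/PPO config keys
    params = {
        "lr": tune.loguniform(1e-5, 1e-2),
        "gamma": tune.choice([0.95, 0.99]),
        "train_batch_size": tune.choice([2000, 4000]),
    }

    runner = DRLlibv2(
        trainable="PPO",
        train_env=train_env,
        train_env_name="CartPole-train",
        params=params,
        search_alg=OptunaSearch(),
        concurrent_trials=2,
        num_samples=8,
        training_iterations=20,
        checkpoint_freq=5,
        checkpoint_num_to_keep=3,
    )
    results = runner.train_tune_model()

    # Inspect the trials and restore the best checkpoint for evaluation
    results_df, best_result = runner.infer_results()
    test_agent = runner.get_test_agent(
        test_env=gym.make("CartPole-v1"), test_env_name="CartPole-test"
    )
    pprint(best_result.metrics)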