# @Author: Astarag Mohapatra

import ray

assert (
    ray.__version__ > "2.0.0"
), "Please install ray 2.2.0 with 'pip install ray[rllib] ray[tune] lz4' (lz4 is needed for population-based training)"

from pprint import pprint

from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.rllib.algorithms import Algorithm
from ray.tune import register_env

from ray.air import RunConfig, FailureConfig, ScalingConfig
from ray.tune.tune_config import TuneConfig
from ray.air.config import CheckpointConfig

import psutil

# Override Ray's system-memory detection with psutil's total memory
# (works around environments, e.g. Colab/containers, where Ray under-reports memory).
psutil_memory_in_bytes = psutil.virtual_memory().total
ray._private.utils.get_system_memory = lambda: psutil_memory_in_bytes

from typing import Dict, Optional, Any, List, Union
class DRLlibv2:
    """Thin wrapper around Ray Tune/RLlib for tuning, training, and restoring agents for testing."""

    def __init__(
        self,
        trainable: Union[str, Any],
        train_env,
        train_env_name: str,
        params: dict,
        run_name: str = "tune_run",
        framework: str = "torch",
        local_dir: str = "tune_results",
        num_workers: int = 1,
        search_alg=None,
        concurrent_trials: int = 0,
        num_samples: int = 0,
        scheduler=None,
        log_level: str = "WARN",
        num_gpus: Union[float, int] = 0,
        num_cpus: Union[float, int] = 2,
        dataframe_save: str = "tune.csv",
        metric: str = "episode_reward_mean",
        mode: Union[str, List[str]] = "max",
        max_failures: int = 0,
        training_iterations: int = 100,
        checkpoint_num_to_keep: Union[None, int] = None,
        checkpoint_freq: int = 0,
        reuse_actors: bool = False,
    ):
        register_env(train_env_name, lambda config: train_env)

        self.params = params

        self.params["framework"] = framework
        self.params["log_level"] = log_level
        self.params["num_gpus"] = num_gpus
        self.params["num_workers"] = num_workers
        self.params["env"] = train_env_name

        self.run_name = run_name
        self.local_dir = local_dir
        self.search_alg = search_alg
        if concurrent_trials != 0:
            self.search_alg = ConcurrencyLimiter(
                self.search_alg, max_concurrent=concurrent_trials
            )
        self.scheduler = scheduler
        self.num_samples = num_samples
        self.trainable = trainable
        if isinstance(self.trainable, str):
            # RLlib registers built-in algorithms under upper-case names (e.g. "PPO")
            self.trainable = self.trainable.upper()
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.dataframe_save = dataframe_save
        self.metric = metric
        self.mode = mode
        self.max_failures = max_failures
        self.training_iterations = training_iterations
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_num_to_keep = checkpoint_num_to_keep
        self.reuse_actors = reuse_actors
    def train_tune_model(self):
        """
        Tunes and trains the model.
        Returns the Tune ResultGrid from tuner.fit().
        """
        ray.init(
            num_cpus=self.num_cpus, num_gpus=self.num_gpus, ignore_reinit_error=True
        )

        tuner = tune.Tuner(
            self.trainable,
            param_space=self.params,
            tune_config=TuneConfig(
                search_alg=self.search_alg,
                scheduler=self.scheduler,
                num_samples=self.num_samples,
                metric=self.metric,
                mode=self.mode,
                reuse_actors=self.reuse_actors,
            ),
            run_config=RunConfig(
                name=self.run_name,
                local_dir=self.local_dir,
                failure_config=FailureConfig(
                    max_failures=self.max_failures, fail_fast=False
                ),
                stop={"training_iteration": self.training_iterations},
                checkpoint_config=CheckpointConfig(
                    num_to_keep=self.checkpoint_num_to_keep,
                    checkpoint_score_attribute=self.metric,
                    checkpoint_score_order=self.mode,
                    checkpoint_frequency=self.checkpoint_freq,
                    checkpoint_at_end=True,
                ),
                verbose=3,
            ),
        )

        self.results = tuner.fit()
        if self.search_alg is not None:
            # persist the searcher state so the tuning run can be restored later
            self.search_alg.save_to_dir(self.local_dir)
        # ray.shutdown()
        return self.results
    def infer_results(self, to_dataframe: Optional[str] = None, mode: str = "a"):
        """
        Saves the tuning results to CSV and returns them as a dataframe,
        along with the best Result object.
        """
        results_df = self.results.get_dataframe()

        if to_dataframe is None:
            to_dataframe = self.dataframe_save

        results_df.to_csv(to_dataframe, mode=mode)

        best_result = self.results.get_best_result()
        # best_metric = best_result.metrics
        # best_checkpoint = best_result.checkpoint
        # best_trial_dir = best_result.log_dir

        return results_df, best_result
    def get_test_agent(self, test_env, test_env_name: str, checkpoint=None):
        """
        Restores an Algorithm from a checkpoint (the best trial's checkpoint by default)
        for evaluation on the test environment.
        """
        register_env(test_env_name, lambda config: test_env)

        if checkpoint is None:
            checkpoint = self.results.get_best_result().checkpoint

        testing_agent = Algorithm.from_checkpoint(checkpoint)
        # testing_agent.config['env'] = test_env_name

        return testing_agent
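

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist): it exercises the wrapper
# end to end using CartPole and PPO purely as stand-ins for the user's own
# environment and search space. The environment names, sample counts, and the
# "lr" search dimension below are illustrative assumptions, and the gym import
# assumes ray 2.2.x, which still uses gym rather than gymnasium.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import gym
    from ray.rllib.algorithms.ppo import PPOConfig

    # Start from PPO's default config dict and add one tunable hyperparameter
    sample_params = PPOConfig().to_dict()
    sample_params["lr"] = tune.loguniform(1e-5, 1e-3)

    drl_agent = DRLlibv2(
        trainable="PPO",
        train_env=gym.make("CartPole-v1"),
        train_env_name="CartPole_train",
        params=sample_params,
        num_samples=2,
        training_iterations=5,
    )

    results = drl_agent.train_tune_model()
    results_df, best_result = drl_agent.infer_results()
    test_agent = drl_agent.get_test_agent(gym.make("CartPole-v1"), "CartPole_test")
    print(results_df.head())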