"""
Hyperparameter tuning and RL optimisation using RLlib, Tune and Neptune.
The script below does the following:
- Declare a NeptuneLogger which extends the tune.logger.Logger class to
send data to the neptune.ai API during training.
- Use Tune SearchAlgorithm to perform a grid search on the "lr" hyperparam.
- Define a Tune Scheduler to perform hyperparameter tune to obtain the
optimal learning rate.
- Setup and run the RLlib env and use Tune to run it.
- Report to Neptune the analysis report, including the best configuration.
"""
import neptune
import json
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.rllib.agents.ppo import PPOTrainer
from ray.tune.logger import Logger, DEFAULT_LOGGERS
import argparse
# Initialisation of the argparser
parser = argparse.ArgumentParser()
parser.add_argument("--training-iterations", type=int,
                    dest="training_iterations")
args = parser.parse_args()

# Fetching of the API key
with open(".neptune-key", "r") as f:
    data = f.read().replace("\n", "")

# Initialisation of the Neptune environment
neptune.init(
    api_token=data,
    project_qualified_name="{name/sandbox}",
)

# Creation of the Neptune experiment
neptune.create_experiment(
    name="IntroductoryExp-v0",
    tags=["alpha"],
    description="Hyperparameter sweeping and RL optimisation using RLlib, Tune and Neptune.",
)
class NeptuneLogger(Logger):
    """
    NeptuneLogger is an extension of the tune.logger.Logger class and is meant
    to be injected into the tune process as a Logger.
    As of now, NeptuneLogger extracts the following from the result object:
    - episode_reward_mean: to plot the evolution of the model in
      `ui.neptune.ai`.
    """

    def _init(self):
        pass

    def close(self):
        pass

    def on_result(self, result):
        neptune.log_metric("episode_reward_mean/" + result["trial_id"],
                           result["episode_reward_mean"])
# Setup the experiment configuration
TUNE_CONFIG = {
    "env": "CartPole-v0",
    "lr": tune.grid_search([1., 0.1, 0.01, 0.001]),
    "log_level": "ERROR",
}
# We run the experiment passing NeptuneLogger as the only Logger.
analysis = tune.run(
    PPOTrainer,
    stop={
        # Default to 10 iterations when --training-iterations is absent or not positive.
        "training_iteration": args.training_iterations
        if args.training_iterations and args.training_iterations > 0
        else 10,
    },
    scheduler=ASHAScheduler(metric="episode_reward_mean", mode="max"),
    config=TUNE_CONFIG,
    loggers=(NeptuneLogger,),
)
# Obtaining the best config from the Tune analysis object
best_config = analysis.get_best_config("episode_reward_mean", mode="max")
neptune.log_text("best_config_lr", "Optimal learning rate is " +
                 str(best_config["lr"]))

# Logging the best configuration to a JSON file.
with open("config.json", "w") as fp:
    json.dump(best_config, fp)
neptune.log_artifact("config.json")

Neptune.ai Introduction with RLlib and Tune

Hyperparameter tuning and training optimisation using RLlib, Tune and Neptune.

Description

This script implements a basic Logger for Tune in order to send training data to neptune.ai.

Usage

You must have a .neptune-key file containing a valid API key in the same directory as the Python script. You must also replace the project_qualified_name placeholder passed to neptune.init with your own value.
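
For instance, assuming a Neptune username of jane and the default sandbox project (both hypothetical), the call in the script would look like this:

neptune.init(
    api_token=data,  # contents of the .neptune-key file
    project_qualified_name="jane/sandbox",  # hypothetical username/project
)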

Dependencies: pip install ray "ray[rllib]" "ray[tune]" neptune-client

Command line usage: python3.8 main.py [--training-iterations={unsigned int}]

Experiment

Configuration

Basic RLlib CartPole-v0 experiment, distributed across 4 trials with different learning rates ([1., 0.1, 0.01, 0.001]) using tune.grid_search().

TUNE_CONFIG = {
    "env": "CartPole-v0",
    "lr": tune.grid_search([1., 0.1, 0.01, 0.001]),
    "log_level": "ERROR",
}
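
As a side note (not part of the original script), each value passed to tune.grid_search() becomes its own trial, so sweeping a second key multiplies the trial count. A hypothetical extension that also sweeps the PPO train_batch_size would yield 4 x 2 = 8 trials:

EXTENDED_TUNE_CONFIG = {
    "env": "CartPole-v0",
    "lr": tune.grid_search([1., 0.1, 0.01, 0.001]),      # 4 values
    "train_batch_size": tune.grid_search([2000, 4000]),  # x 2 values -> 8 trials
    "log_level": "ERROR",
}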

Results

It was found that the optimal learning rate is 0.001.

Below are graphs depicting the evolution of the reward. We can see that trial 18db8_00000 was evicted at iteration 4 by the ASHAScheduler due to poor results.

[Figure: per-trial episode_reward_mean curves from the experiment]

In a similar experiment, the results were the following:

+-----------------------------+------------+-------+-------+--------+------------------+-------+----------+
| Trial name                  | status     | loc   |    lr |   iter |   total time (s) |    ts |   reward |
|-----------------------------+------------+-------+-------+--------+------------------+-------+----------|
| PPO_CartPole-v0_708c7_00000 | TERMINATED |       | 1     |      1 |         10.398   |  4000 |  22.3296 |
| PPO_CartPole-v0_708c7_00001 | TERMINATED |       | 0.1   |     10 |         62.7582  | 40000 |  14.9813 |
| PPO_CartPole-v0_708c7_00002 | TERMINATED |       | 0.01  |      1 |          8.89217 |  4000 |  22.0552 |
| PPO_CartPole-v0_708c7_00003 | TERMINATED |       | 0.001 |     10 |         47.4788  | 40000 | 185.32   |
+-----------------------------+------------+-------+-------+--------+------------------+-------+----------+
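
Since the script saves the best configuration to config.json and uploads it as a Neptune artifact, it can be reloaded later to keep training with the tuned learning rate. A minimal sketch, assuming config.json sits next to the script and the same Ray/RLlib version is installed:

import json
from ray.rllib.agents.ppo import PPOTrainer

# Load the best configuration produced by the sweep.
with open("config.json") as fp:
    best_config = json.load(fp)

# Re-instantiate a PPO trainer with the tuned hyperparameters
# ("env" is read from the loaded config).
trainer = PPOTrainer(config=best_config)
for _ in range(10):
    print(trainer.train()["episode_reward_mean"])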