Skip to content

Instantly share code, notes, and snippets.

@kepricon
Created February 8, 2021 19:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kepricon/516800ad7d8ef1b0b23df429c4d49490 to your computer and use it in GitHub Desktop.
Save kepricon/516800ad7d8ef1b0b23df429c4d49490 to your computer and use it in GitHub Desktop.
import random
import ray
from ray.tune import run, sample_from
from ray.tune.schedulers import PopulationBasedTraining
class Stopper:
    """Stop a Tune trial once it reaches a fixed number of training iterations.

    The bound ``stop`` method is passed to ``ray.tune.run(stop=...)``, which
    calls it as ``stop(trial_id, result)`` with each reported result.
    """

    def __init__(self, max_iter=10):
        """Create a stopper.

        Args:
            max_iter: ``training_iteration`` value at (or above) which a trial
                is stopped. Defaults to 10, the previously hard-coded value.
        """
        self.max_iter = max_iter
        # Outcome of the most recent ``stop`` call. A single instance is shared
        # by every trial, so this only reflects whichever trial reported last.
        self.too_many_iter = False

    def stop(self, trial_id, result):
        """Decide whether the reporting trial should stop.

        Args:
            trial_id: Identifier of the reporting trial (unused).
            result: Result dict reported by the trial; must contain the key
                ``'training_iteration'``.

        Returns:
            ``True`` once ``result['training_iteration'] >= max_iter``,
            otherwise ``False`` (the original returned ``None`` here).
        """
        self.too_many_iter = result["training_iteration"] >= self.max_iter
        return self.too_many_iter


def explore(config):
    """Postprocess a PBT-perturbed config so it is still valid.

    ``config`` is modified in place and also returned.

    Args:
        config: Trial config dict containing ``train_batch_size``,
            ``sgd_minibatch_size`` and ``num_sgd_iter``.

    Returns:
        The same ``config`` object, with:
        ``train_batch_size`` raised to at least ``2 * sgd_minibatch_size``
        (so a train batch holds at least two minibatches), and
        ``num_sgd_iter`` raised to at least 1.
    """
    min_train_batch = config["sgd_minibatch_size"] * 2
    if config["train_batch_size"] < min_train_batch:
        config["train_batch_size"] = min_train_batch
    if config["num_sgd_iter"] < 1:
        config["num_sgd_iter"] = 1
    return config


def main():
    """Run PPO on CartPole-v0 under a Population Based Training scheduler.

    Initializes Ray, then launches one Tune experiment named ``cartpole``
    with ``local_dir='/tmp/PPO'`` and ``export_formats=['model']``; trials
    are stopped by ``Stopper().stop`` (10 training iterations).
    """
    pbt = PopulationBasedTraining(
        time_attr="time_total_s",
        metric="episode_reward_mean",
        mode="max",
        # Measured in units of time_attr, i.e. seconds of total training time.
        perturbation_interval=120,
        resample_probability=0.25,
        # Specifies the mutations of these hyperparams
        hyperparam_mutations={
            "lambda": lambda: random.uniform(0.9, 1.0),
            "clip_param": lambda: random.uniform(0.01, 0.5),
            "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
            "num_sgd_iter": lambda: random.randint(1, 30),
            # NOTE(review): this range reaches 16384 while train_batch_size
            # below is drawn from 200-1600; explore() then raises
            # train_batch_size to 2 * sgd_minibatch_size, overriding that
            # range. Confirm the ranges are intended.
            "sgd_minibatch_size": lambda: random.randint(128, 16384),
            "train_batch_size": lambda: random.randint(200, 1600),
        },
        custom_explore_fn=explore,
    )

    ray.init()
    run(
        "PPO",
        name="cartpole",
        scheduler=pbt,
        # NOTE(review): with a population of a single trial there is no other
        # trial to exploit, and 10 iterations may finish before 120 s of
        # training time; PBT may therefore never perturb. Confirm intended.
        num_samples=1,
        config={
            "env": "CartPole-v0",
            "kl_coeff": 1.0,
            "num_workers": 1,
            "num_gpus": 0,
            "model": {
                # NOTE(review): RLlib documents free_log_std for DiagGaussian
                # (continuous) action distributions; CartPole's action space
                # is discrete. Confirm this is intended.
                "free_log_std": True,
            },
            # These params are tuned from a fixed starting value.
            "lambda": 0.95,
            "clip_param": 0.2,
            "lr": 1e-4,
            # These params start off randomly drawn from a set.
            "num_sgd_iter": sample_from(
                lambda spec: random.choice([10, 20])),
            "sgd_minibatch_size": sample_from(
                lambda spec: random.choice([128, 512])),
            "train_batch_size": sample_from(
                lambda spec: random.choice([1000, 2000])),
        },
        stop=Stopper().stop,
        local_dir="/tmp/PPO",
        export_formats=["model"],
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment