import argparse

import gym
from gym.spaces import Discrete

import ray
from ray import tune
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")  # RLlib algorithm to run
parser.add_argument("--torch", action="store_true")  # train with PyTorch instead of TF
parser.add_argument("--as-test", action="store_true")  # assert the target reward was reached
parser.add_argument("--stop-iters", type=int, default=50)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=0.1)
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    You can configure the length of the corridor via the env config.
    """

    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        # Two actions: 0 = move left, 1 = move right.
        self.action_space = Discrete(2)
        # Observation is the agent's current position in the corridor.
        self.observation_space = Discrete(self.end_pos + 1)

    def reset(self):
        self.cur_pos = 0
        return self.cur_pos

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        # Small per-step penalty; +1.0 for reaching the end of the corridor.
        return self.cur_pos, 1.0 if done else -0.1, done, {}
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
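    # Optional registration sketch (an addition; the original gist's config
    # comment refers to such a registration but the code for it was missing):
    # register the env under a string name so the config below may reference
    # it as "corridor" instead of passing the class object directly.
    from ray.tune.registry import register_env

    register_env("corridor", lambda env_config: SimpleCorridor(env_config))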
    config = {
        "env": SimpleCorridor,  # or "corridor", the name registered above
        "env_config": {
            "corridor_length": 5,
        },
        # Select the deep-learning framework from the --torch flag.
        "framework": "torch" if args.torch else "tf",
    }
    # Training stops as soon as any one of these conditions is met.
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
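As a quick sanity check (a sketch added here, not part of the original gist), the environment can be rolled out by hand with a fixed "always move right" policy. Assuming the SimpleCorridor class above and corridor_length=5, the episode ends after five steps with a return of 4 * -0.1 + 1.0 = 0.6:

env = SimpleCorridor({"corridor_length": 5})
obs = env.reset()
total_reward, done = 0.0, False
while not done:
    # Action 1 always moves right, so the goal is reached in end_pos steps.
    obs, reward, done, info = env.step(1)
    total_reward += reward
print(total_reward)  # -> 0.6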