A starter example that trains, checkpoints, and evaluates an RL algorithm in RLlib.
"""
Reference: https://github.com/ray-project/ray/blob/f8a91c7fad248b1c7f81fd6d30191ac930a92bc4/rllib/examples/env/simple_corridor.py
Fixes:
ValueError: ('Observation ({}) outside given space ({})!', array([0.]), Box([0.], [999.], (1,), float32))
"""
import gym
from gym.spaces import Box, Discrete
import numpy as np


class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    You can configure the length of the corridor via the env config."""

    def __init__(self, config=None):
        config = config or {}
        self.end_pos = config.get("corridor_length", 10)
        self.start_pos = config.get("corridor_start", 0)
        self.cur_pos = self.start_pos
        self.action_space = Discrete(2)
        # Bound the observation space to [start_pos, end_pos] so every
        # position the agent can reach lies inside the Box.
        self.observation_space = Box(
            self.start_pos, self.end_pos, shape=(1,), dtype=np.float32)

    def set_corridor_length(self, length):
        self.end_pos = length
        self.observation_space = Box(
            self.start_pos, self.end_pos, shape=(1,), dtype=np.float32)
        print("Updated corridor length to {}".format(length))

    def reset(self):
        self.cur_pos = self.start_pos
        return np.full(self.observation_space.shape, self.cur_pos, dtype=np.float32)

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > self.start_pos:
            self.cur_pos -= 1.0
            # Clamp to start_pos so the observation never leaves the Box.
            self.cur_pos = max(self.cur_pos, self.start_pos)
        elif action == 1:
            self.cur_pos += 1.0
            self.cur_pos = min(self.cur_pos, self.end_pos)
        done = self.cur_pos >= self.end_pos
        obs = np.full(self.observation_space.shape, self.cur_pos, dtype=np.float32)
        return obs, 1 if done else 0, done, {}
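
Before training, it can help to sanity-check the environment by rolling it out with random actions and asserting that every observation stays inside the declared space; this is exactly the condition behind the ValueError mentioned in the docstring above. A minimal sketch (a hypothetical helper, not part of the original gist), assuming simple_corridor.py is importable:

# sanity_check.py (hypothetical helper, not part of the original gist)
from simple_corridor import SimpleCorridor

env = SimpleCorridor({"corridor_length": 5})
obs = env.reset()
assert env.observation_space.contains(obs), obs

done = False
while not done:
    action = env.action_space.sample()  # random policy
    obs, reward, done, info = env.step(action)
    # Every observation must lie inside the declared Box, otherwise RLlib
    # raises "Observation (...) outside given space (...)".
    assert env.observation_space.contains(obs), obs
print("Random rollout finished without leaving the observation space.")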
# Training / evaluation script (a separate file; it imports the environment
# from simple_corridor.py above).
import ray
from ray import tune
import ray.rllib.agents.ppo as ppo

from simple_corridor import SimpleCorridor

config = {
    "env": SimpleCorridor,
    "env_config": {
        "corridor_length": 5,
    },
}
stop = {
    "training_iteration": 3,
}

ray.init()

# Train with Tune and write a checkpoint at the end of the run.
results = tune.run(
    "PPO",
    config=config,
    stop=stop,
    checkpoint_at_end=True,
)

# Get the path to the best checkpoint (by mean episode reward).
results.default_metric = "episode_reward_mean"
results.default_mode = "max"
checkpoint_path = results.best_checkpoint
print(checkpoint_path)

# Load the trained model.
agent = ppo.PPOTrainer(
    config=config,
    env=SimpleCorridor,
)
agent.restore(checkpoint_path)

# Run the trained model for one episode, acting greedily.
env = SimpleCorridor(config=config["env_config"])
done = False
obs = env.reset()
step = 0
episode_reward = 0
while not done:
    action = agent.compute_action(obs, explore=False)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print(step, action, obs, reward, done, info)
    step += 1
print(episode_reward)
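
A single greedy rollout is a noisy estimate of policy quality. Below is a short sketch, reusing the agent and env objects from the script above (the episode count is an arbitrary choice), that averages the return over several episodes and then shuts Ray down:

# Optional follow-up: average the greedy return over several episodes.
num_episodes = 10  # arbitrary choice for illustration
returns = []
for _ in range(num_episodes):
    obs = env.reset()
    done = False
    total = 0
    while not done:
        action = agent.compute_action(obs, explore=False)
        obs, reward, done, _ = env.step(action)
        total += reward
    returns.append(total)
print("Mean return over {} episodes: {}".format(num_episodes, sum(returns) / num_episodes))
ray.shutdown()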