Last active
October 6, 2021 03:53
-
-
Save rlan/775607f436894ba596ecf1f0ad194e98 to your computer and use it in GitHub Desktop.
A starter example that trains, checkpoints, and evaluates an RL algorithm in RLlib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Reference: https://github.com/ray-project/ray/blob/f8a91c7fad248b1c7f81fd6d30191ac930a92bc4/rllib/examples/env/simple_corridor.py | |
Fixes: | |
ValueError: ('Observation ({}) outside given space ({})!', array([0.]), Box([0.], [999.], (1,), float32)) | |
""" | |
import gym | |
from gym.spaces import Box, Discrete | |
import numpy as np | |
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    The agent starts at ``corridor_start`` and must reach ``corridor_length``
    (the right end of the corridor). Action 0 moves one step left, action 1
    moves one step right. Reaching the end yields reward 1 and terminates the
    episode; every other step yields reward 0.

    You can configure the length of the corridor via the env config.
    """

    def __init__(self, config=None):
        config = config or {}
        self.end_pos = config.get("corridor_length", 10)
        self.start_pos = config.get("corridor_start", 0)
        self.cur_pos = self.start_pos
        self.action_space = Discrete(2)
        # Bound the observation space to [start_pos, end_pos] so positions
        # always fall inside it (avoids the "Observation ... outside given
        # space" ValueError referenced in the module docstring).
        self.observation_space = Box(
            self.start_pos, self.end_pos, shape=(1,), dtype=np.float32)

    def set_corridor_length(self, length):
        """Resize the corridor and rebuild the observation space to match."""
        self.end_pos = length
        self.observation_space = Box(
            self.start_pos, self.end_pos, shape=(1,), dtype=np.float32)
        print("Updated corridor length to {}".format(length))

    def reset(self):
        """Reset the walker to the start position and return the first obs."""
        self.cur_pos = self.start_pos
        return np.full(self.observation_space.shape, self.cur_pos, dtype=np.float32)

    def step(self, action):
        """Advance one timestep.

        Returns the classic gym 4-tuple ``(obs, reward, done, info)``.
        """
        assert action in [0, 1], action
        if action == 0:
            # BUGFIX: the original guard was `self.cur_pos > 0`, which pinned
            # the agent at position 0 even when corridor_start < 0, leaving
            # part of the declared observation space unreachable. Clamping to
            # start_pos handles every configuration (identical behavior for
            # the default start_pos=0).
            self.cur_pos = max(self.cur_pos - 1.0, self.start_pos)
        elif action == 1:
            self.cur_pos = min(self.cur_pos + 1.0, self.end_pos)
        done = self.cur_pos >= self.end_pos
        obs = np.full(self.observation_space.shape, self.cur_pos, dtype=np.float32)
        return obs, 1 if done else 0, done, {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray
# Explicit tune import: `import ray` alone does not expose `ray.tune`; the
# original only worked because the rllib import pulled tune in transitively.
from ray import tune
import ray.rllib.agents.ppo as ppo

from simple_corridor import SimpleCorridor

# Trainer configuration: which env class to run and its env-specific settings.
config = {
    "env": SimpleCorridor,
    "env_config": {
        "corridor_length": 5,
    },
}
# Tune stop criteria: halt the experiment after 3 training iterations.
stop = {
    "training_iteration": 3,
}

ray.init()

# Train: run PPO under Tune and write a checkpoint when training stops.
results = tune.run(
    "PPO",
    config=config,
    stop=stop,
    checkpoint_at_end=True,
)

# Select the best checkpoint by mean episode reward (setting default_metric/
# default_mode lets ExperimentAnalysis.best_checkpoint resolve it).
results.default_metric = "episode_reward_mean"
results.default_mode = "max"
checkpoint_path = results.best_checkpoint
print(checkpoint_path)

# Load the trained model into a fresh PPO trainer.
agent = ppo.PPOTrainer(
    config=config,
    env=SimpleCorridor,
)
agent.restore(checkpoint_path)

# Roll out one greedy (explore=False) episode with the trained policy.
env = SimpleCorridor(config=config["env_config"])
done = False
obs = env.reset()
step = 0
episode_reward = 0
while not done:
    action = agent.compute_action(obs, explore=False)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print(step, action, obs, reward, done, info)
    step += 1
print(episode_reward)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment