Skip to content

Instantly share code, notes, and snippets.

@rlan
Created October 6, 2021 03:50
Show Gist options
  • Save rlan/7e93acb251f208f0d737308e4725f412 to your computer and use it in GitHub Desktop.
Save rlan/7e93acb251f208f0d737308e4725f412 to your computer and use it in GitHub Desktop.
"""
Reference: https://github.com/ray-project/ray/blob/f8a91c7fad248b1c7f81fd6d30191ac930a92bc4/rllib/examples/env/simple_corridor.py
Fixes:
ValueError: ('Observation ({}) outside given space ({})!', array([0.]), Box([0.], [999.], (1,), float32))
"""
import gym
from gym.spaces import Box, Discrete
import numpy as np
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    The agent starts at ``corridor_start`` (default 0) and the episode is
    done once its position reaches ``corridor_length`` (default 10). Both
    values are read from the env config dict.

    Actions:
        0: move one step back towards the start (never past it).
        1: move one step forward (never past the end).

    Observations are ``float32`` arrays with the shape of the observation
    space, ``(1,)``, filled with the current position.
    """

    def __init__(self, config=None):
        """Create the corridor.

        Args:
            config: Optional dict; the keys ``"corridor_length"`` and
                ``"corridor_start"`` are read from it. ``None`` (or any
                falsy value) means "use the defaults".
        """
        config = config or {}
        self.end_pos = config.get("corridor_length", 10)
        self.start_pos = config.get("corridor_start", 0)
        self.cur_pos = self.start_pos
        self.action_space = Discrete(2)
        self.observation_space = self._make_observation_space()

    def _make_observation_space(self):
        """Build the Box spanning the current [start_pos, end_pos] range."""
        return Box(self.start_pos, self.end_pos, shape=(1, ), dtype=np.float32)

    def _observation(self):
        """Return the current position as a float32 array of the space's shape."""
        return np.full(
            self.observation_space.shape, self.cur_pos, dtype=np.float32)

    def set_corridor_length(self, length):
        """Change the end position and rebuild the observation space to match.

        The current position is left untouched.
        """
        self.end_pos = length
        self.observation_space = self._make_observation_space()
        print("Updated corridor length to {}".format(length))

    def reset(self):
        """Move the agent back to the start and return the first observation."""
        self.cur_pos = self.start_pos
        return self._observation()

    def step(self, action):
        """Apply one action.

        Args:
            action: 0 to step back, 1 to step forward.

        Returns:
            A ``(obs, reward, done, info)`` tuple. ``reward`` is 1 on the
            step that reaches the end of the corridor and 0 otherwise;
            ``info`` is always an empty dict.
        """
        assert action in [0, 1], action
        # Compare against the configured start rather than a literal 0, so
        # that corridors with a negative start can still be walked backwards.
        if action == 0 and self.cur_pos > self.start_pos:
            # Clamp so a fractional start position is never overshot.
            self.cur_pos = max(self.cur_pos - 1.0, self.start_pos)
        elif action == 1:
            self.cur_pos = min(self.cur_pos + 1.0, self.end_pos)
        done = self.cur_pos >= self.end_pos
        return self._observation(), 1 if done else 0, done, {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment