Skip to content

Instantly share code, notes, and snippets.

@BenedictWilkins
Created October 28, 2020 15:50
Show Gist options
  • Save BenedictWilkins/d8ecc4c10cc5032ebf3022895e48506f to your computer and use it in GitHub Desktop.
Save BenedictWilkins/d8ecc4c10cc5032ebf3022895e48506f to your computer and use it in GitHub Desktop.
OpenAI Gym observation slicing
import numpy as np
import gym
class BoxSlice(gym.spaces.Box):
    """A ``gym.spaces.Box`` seen through a NumPy-style index.

    ``low``, ``high`` and ``shape`` are computed on demand by applying the
    stored index to the bounds of the wrapped box. Indexing a ``BoxSlice``
    yields another ``BoxSlice`` built on top of the already-sliced bounds.
    """

    def __init__(self, box, s_=np.s_[:, :, :]):
        """Create a sliced view of *box*.

        Args:
            box: The ``gym.spaces.Box`` whose bounds are sliced.
            s_: Index applied to ``box.low`` and ``box.high``. The default
                consists of three full slices, so it requires bounds with at
                least three dimensions.

        Raises:
            TypeError: If *box* is not a ``gym.spaces.Box``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not isinstance(box, gym.spaces.Box):
            raise TypeError(
                f"box must be a gym.spaces.Box, got {type(box).__name__}"
            )
        # Stored before the parent initializer runs because the properties
        # below read them.
        self.__box = box
        self.__slice = s_
        # Hand the parent the *sliced* bounds so that whatever it derives
        # from its constructor arguments agrees with the properties below.
        super().__init__(box.low[s_], box.high[s_], dtype=box.dtype)

    def __getitem__(self, i):
        """Return a new ``BoxSlice`` that applies index *i* to this space."""
        return BoxSlice(self, i)

    @property
    def s_(self):
        """The index this space applies to the wrapped box's bounds."""
        return self.__slice

    @property
    def low(self):
        """Lower bounds of the wrapped box, restricted to the stored index."""
        return self.__box.low[self.__slice]

    @low.setter
    def low(self, value):
        # Deliberately ignores assignment: the value always derives from the
        # wrapped box. Presumably present so that assignments made by the
        # parent initializer do not fail on a read-only property.
        pass

    @property
    def high(self):
        """Upper bounds of the wrapped box, restricted to the stored index."""
        return self.__box.high[self.__slice]

    @high.setter
    def high(self, value):
        # Deliberately ignores assignment; see the ``low`` setter.
        pass

    @property
    def shape(self):
        """Shape of the sliced lower-bound array."""
        return self.__box.low[self.__slice].shape

    @shape.setter
    def shape(self, value):
        # Deliberately ignores assignment; see the ``low`` setter.
        pass
class Slice(gym.Wrapper):
    """Environment wrapper that applies a NumPy-style index to observations.

    The wrapper's ``observation_space`` is a ``BoxSlice`` over the wrapped
    environment's observation space, and every observation returned by
    ``step`` and ``reset`` is indexed with that space's ``s_``. Indexing the
    wrapper itself (``env[...]``) returns a new ``Slice`` wrapped around it,
    so successive indices are applied one after another.
    """

    def __init__(self, env, s_=np.s_[:, :, :]):
        """Wrap *env* and slice its observations with *s_*.

        Args:
            env: The environment to wrap; its ``observation_space`` is passed
                to ``BoxSlice``.
            s_: Index applied to every observation. The default consists of
                three full slices.
        """
        super().__init__(env)
        self.observation_space = BoxSlice(env.observation_space, s_=s_)

    def __getitem__(self, i):
        """Return a new ``Slice`` around this wrapper that applies index *i*."""
        return Slice(self, s_=i)

    def step(self, action, *args, **kwargs):
        """Step the wrapped environment and slice the returned observation.

        Returns:
            A tuple whose first element is the sliced observation; all
            remaining elements are passed through unchanged.
        """
        observation, *rest = self.env.step(action, *args, **kwargs)
        return (observation[self.observation_space.s_], *rest)

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and slice the initial observation.

        Returns:
            The sliced observation. If the wrapped ``reset`` returns a tuple
            (for example ``(observation, info)``), only its first element is
            sliced and a tuple with the remaining elements unchanged is
            returned.
        """
        result = self.env.reset(*args, **kwargs)
        index = self.observation_space.s_
        # A tuple return must not be indexed as a whole: a slice index would
        # silently pick tuple elements instead of cropping the observation.
        if isinstance(result, tuple):
            observation, *rest = result
            return (observation[index], *rest)
        return result[index]
if __name__ == "__main__":
    # Small demo: wrap Pong and print the spaces / observation shapes that
    # result from a few different indices.
    env = Slice(gym.make("Pong-v0"))

    # Drop the first entry along the leading axis.
    cropped = env[1:]
    print(cropped.observation_space)

    # Entries 10..19 on the first axis, every second entry on the next one.
    strided = env[10:20, ::2, :]
    print(strided.observation_space)

    # Default index: the observation comes back with its full shape.
    print(env.reset().shape)

    # Keep only the first entry of the last axis (the red channel, per the
    # original author's note).
    first_channel = env[:, :, :1]
    print(first_channel.reset().shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment