Skip to content

Instantly share code, notes, and snippets.

@ttumiel
Created February 15, 2023 22:20
Show Gist options
Save ttumiel/c2132b424c49b76a62bafe7efef9923d to your computer and use it in GitHub Desktop.
# Dummy wrappers for testing a complex observation space
### Dummy Complex Obs Wrapper ###
class DummyComplex(gym.ObservationWrapper):
    """Wrap an env so each observation becomes a Dict with an extra random entry.

    Every observation is returned as ``{"dummy": <random array>, "original": obs}``,
    where ``obs`` is the wrapped env's unmodified observation. This exists only to
    exercise code paths that must handle a composite (Dict) observation space.

    Args:
        env: Environment to wrap. Its ``observation_space`` becomes the
            ``"original"`` sub-space.
        dummy_shape: Shape of the random ``"dummy"`` entry. Defaults to
            ``(12, 12)``, the shape that was previously hard-coded.
    """

    def __init__(self, env, dummy_shape=(12, 12)):
        super().__init__(env)
        self.observation_space = gym.spaces.Dict(
            {
                "dummy": gym.spaces.Box(0, 255, tuple(dummy_shape)),
                "original": self.env.observation_space,
            }
        )

    def observation(self, observation):
        """Return ``observation`` paired with a freshly sampled dummy array.

        Dummy values are uniform in [0, 255). They are cast to the ``"dummy"``
        Box's dtype because ``np.random.random`` yields float64 while ``Box``
        defaults to float32. Without the cast, the returned observation would
        not satisfy ``self.observation_space.contains``.
        """
        dummy_space = self.observation_space["dummy"]
        # NOTE(review): this draws from NumPy's global RNG, so seeding the env
        # does not make the dummy values reproducible.
        dummy = 255 * np.random.random(dummy_space.shape)
        return {"dummy": dummy.astype(dummy_space.dtype), "original": observation}
### Dummy agent that consumes the Dict observation space produced by DummyComplex ###
class ComplexObsAgent(nn.Module):
    """Adapter that lets an ``Agent`` consume Dict observations.

    Observations arrive as a mapping. Only the ``"original"`` entry is handed
    to the wrapped ``Agent``; any other entries (such as ``"dummy"``) are
    ignored here but remain available if a network wants to use them.
    """

    def __init__(self, envs):
        super().__init__()
        # Inner network that provides get_value / get_action_and_value,
        # constructed from the same ``envs`` object.
        self.net = Agent(envs)

    @staticmethod
    def _base_obs(x):
        """Select the part of the composite observation the inner Agent accepts."""
        return x["original"]

    def get_value(self, x):
        """Return the inner Agent's ``get_value`` result for ``x["original"]``."""
        inner_obs = self._base_obs(x)
        return self.net.get_value(inner_obs)

    def get_action_and_value(self, x, action=None):
        """Delegate to the inner Agent on ``x["original"]``, passing ``action`` through."""
        inner_obs = self._base_obs(x)
        return self.net.get_action_and_value(inner_obs, action)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment