Created
February 15, 2023 22:20
-
-
Save ttumiel/c2132b424c49b76a62bafe7efef9923d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dummy wrappers for testing a complex observation space


### Dummy Complex Obs Wrapper ###
class DummyComplex(gym.ObservationWrapper):
    """Turn the wrapped env's observations into a ``Dict`` observation.

    Every observation becomes ``{"dummy": <random array>, "original": <obs>}``.
    ``"original"`` is the wrapped env's observation unchanged. ``"dummy"`` is
    random noise in ``[0, 255)``. It only exists so that code paths handling
    dict observation spaces can be exercised.
    """

    def __init__(self, env, shape=(12, 12)):
        """
        Args:
            env: the environment to wrap.
            shape: shape of the random ``"dummy"`` array (default ``(12, 12)``,
                matching the previous hard-coded value).
        """
        super().__init__(env)
        self._dummy_shape = tuple(shape)
        self.observation_space = gym.spaces.Dict(
            {
                # dtype is pinned explicitly so it provably matches what
                # observation() emits below.
                "dummy": gym.spaces.Box(0, 255, self._dummy_shape, dtype=np.float32),
                "original": self.env.observation_space,
            }
        )

    def observation(self, observation):
        """Return the wrapped observation alongside a fresh random ``"dummy"`` array."""
        # np.random.random produces float64 in [0, 1). Cast it to the space's
        # float32 dtype, otherwise Box.contains() rejects the sample, because
        # float64 cannot be safely cast to float32.
        # NOTE(review): this draws from numpy's global RNG, so it is not tied to
        # the env's seed; switch to a seeded generator if reproducibility matters.
        dummy = (255 * np.random.random(self._dummy_shape)).astype(np.float32)
        return {"dummy": dummy, "original": observation}
### Dummy Complex obs agent that handles the complex obs space that we defined ###
class ComplexObsAgent(nn.Module):
    """Adapter that lets an ``Agent`` consume dict-valued observations.

    Both public methods take the ``"original"`` entry out of the observation
    dict ``x`` and delegate to the wrapped ``Agent`` (``self.net``). Any other
    keys in ``x`` (for example ``"dummy"``) are ignored.
    """

    def __init__(self, envs):
        super().__init__()
        self.net = Agent(envs)

    @staticmethod
    def _select(x):
        # Any of the obs values in `x` could be used here; this adapter only
        # forwards the wrapped env's original observation.
        return x["original"]

    def get_value(self, x):
        """Return ``self.net.get_value`` applied to ``x["original"]``."""
        inner_obs = self._select(x)
        return self.net.get_value(inner_obs)

    def get_action_and_value(self, x, action=None):
        """Return ``self.net.get_action_and_value`` applied to ``x["original"]``.

        ``action`` is passed through positionally and unchanged.
        """
        inner_obs = self._select(x)
        return self.net.get_action_and_value(inner_obs, action)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment