Skip to content

Instantly share code, notes, and snippets.

@christopherhesse
Created January 7, 2019 22:11
Show Gist options
  • Save christopherhesse/dc5a7ed99704870592d7c3264e0dbd6c to your computer and use it in GitHub Desktop.
Save christopherhesse/dc5a7ed99704870592d7c3264e0dbd6c to your computer and use it in GitHub Desktop.
def scalar_adapter(venv_cls):
    """Create a single-environment wrapper class around a vectorized env class.

    The returned ``ScalarEnv`` class instantiates ``venv_cls`` with
    ``num_envs=1`` and strips the leading batch axis from everything the
    vectorized environment returns, while adding it to the actions passed in.

    This block references ``np`` and ``gym`` as module-level names; they must
    be available in the enclosing module.

    Args:
        venv_cls: A callable/class accepting a ``num_envs`` keyword argument
            (plus arbitrary keyword arguments) whose instances expose
            ``observation_space``, ``action_space``, ``metadata``, ``spec``,
            ``reset``, ``step``, ``render`` and ``close``.

    Returns:
        The ``ScalarEnv`` class bound to ``venv_cls``.
    """

    class ScalarEnv:
        """Un-batched view of a ``venv_cls`` instance holding one environment."""

        def __init__(self, **kwargs) -> None:
            """Construct the wrapped env with ``num_envs=1`` and mirror its spaces.

            Args:
                **kwargs: Forwarded unchanged to ``venv_cls``.
            """
            self._venv = venv_cls(num_envs=1, **kwargs)
            self.observation_space = self._venv.observation_space
            self.action_space = self._venv.action_space
            self.metadata = self._venv.metadata
            self.spec = self._venv.spec

        def _process_obs(self, obs):
            """Return the first (only) entry of a batched observation.

            Args:
                obs: Either an ``np.ndarray`` or a mapping whose values are
                    indexable along their first axis.

            Returns:
                ``obs[0]`` for arrays, otherwise a new dict mapping every key
                to the first element of its value.
            """
            if isinstance(obs, np.ndarray):
                return obs[0]
            return {key: value[0] for key, value in obs.items()}

        def reset(self):
            """Reset the wrapped env and return the un-batched observation."""
            return self._process_obs(self._venv.reset())

        def step(self, action):
            """Step the wrapped env with a single action.

            Args:
                action: One action. For ``gym.spaces.Discrete`` action spaces
                    it is wrapped in a length-1 array of the space's dtype;
                    otherwise a leading axis is added with ``np.expand_dims``.

            Returns:
                A ``(obs, reward, done, info)`` tuple built from the first
                entry of each value returned by the wrapped env's ``step``.
            """
            if isinstance(self.action_space, gym.spaces.Discrete):
                action = np.array([action], dtype=self._venv.action_space.dtype)
            else:
                action = np.expand_dims(action, axis=0)
            obs, rews, dones, infos = self._venv.step(action)
            return self._process_obs(obs), rews[0], dones[0], infos[0]

        def render(self, mode='human'):
            """Render via the wrapped env.

            Args:
                mode: Render mode forwarded to the wrapped env.

            Returns:
                The wrapped env's result unchanged when ``mode == 'human'``;
                otherwise the first element of that result.
            """
            result = self._venv.render(mode=mode)
            if mode == 'human':
                return result
            return result[0]

        def close(self):
            """Close the wrapped env and return whatever its ``close`` returns."""
            return self._venv.close()

        def __repr__(self):
            return f'<ScalarEnv venv={self._venv}>'

        def __getattr__(self, name):
            """Delegate attributes not found on the wrapper to the wrapped env.

            Raises:
                AttributeError: If ``name`` is ``_venv`` (the wrapper was never
                    initialised) or the wrapped env lacks the attribute.
            """
            # __getattr__ only runs after normal lookup fails.  When ``_venv``
            # itself is missing (instance built without __init__, e.g. while
            # copy/pickle reconstructs the object), touching ``self._venv``
            # would re-enter this method forever; fail fast instead.
            if name == '_venv':
                raise AttributeError(name)
            return getattr(self._venv, name)

    return ScalarEnv
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment