-
-
Save christopherhesse/fd8a6593df61bafb4af6ef8dcdca11b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
class MoreDeterministicRetroWrapper(gym.Wrapper):
    """
    Save/restore state on each step to avoid de-sync.

    Before every step, the emulator state is captured with ``get_state()`` and
    immediately restored with ``reset(state=...)``, so each step starts from a
    freshly loaded save state.

    It's possible that reward and done will not be correct if they
    depend on lua state (e.g. Sonic "contest" scenario).

    For most emulated systems this is 10%-50% slower; for Atari2600 it is
    60x slower. It's unclear why stella is so slow to save/load a state.

    If other wrappers have state (such as TimeLimit), they would need to be
    extended to support get_state() and reset(state=state), and then this
    class would need to make sure parent methods are called.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Most recent observation. It is stored next to the emulator state so
        # that reset(state=...) can return it without stepping the emulator.
        self._last_obs = None
        # True once the last step reported done; get_state() refuses to
        # capture a terminal state.
        self._done = False

    def reset(self, state=None, **kwargs):
        """Reset the environment, or restore a state captured by ``get_state()``.

        Args:
            state: ``None`` to perform a normal reset of the wrapped env, or an
                ``(emulator_state, observation)`` tuple previously returned by
                ``get_state()`` to restore.
            **kwargs: Forwarded to ``self.env.reset`` on a normal reset;
                ignored when restoring ``state``.

        Returns:
            The observation for the reset or restored state.
        """
        self._done = False
        if state is None:
            self._last_obs = self.env.reset(**kwargs)
        else:
            em_state, self._last_obs = state
            unwrapped = self.unwrapped
            unwrapped.em.set_state(em_state)
            # Presumably re-syncs retro's game data / RAM view with the freshly
            # loaded emulator state -- TODO confirm against the retro API.
            unwrapped.data.reset()
            unwrapped.data.update_ram()
        return self._last_obs

    def step(self, act):
        """Step the wrapped env after a save/restore round-trip of the emulator.

        Args:
            act: Action passed through to the wrapped environment.

        Returns:
            ``(obs, reward, done, info)`` as returned by the wrapped env.

        Raises:
            AssertionError: If called after the episode has ended (via
                ``get_state``; only when assertions are enabled).
        """
        # Save and immediately reload the emulator state before stepping,
        # which is the mechanism this wrapper uses to avoid de-sync.
        self.reset(state=self.get_state())
        # Delegate to the wrapped env; calling self.step here would recurse
        # forever.
        self._last_obs, rew, self._done, info = self.env.step(act)
        return self._last_obs, rew, self._done, info

    def get_state(self):
        """Capture the current emulator state together with the last observation.

        Returns:
            An ``(emulator_state, observation)`` tuple accepted by
            ``reset(state=...)``.
        """
        assert not self._done, "cannot store a terminal state"
        return (self.unwrapped.em.get_state(), self._last_obs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment