Last active
April 14, 2019 14:53
-
-
Save EliasHasle/5958fceccbb10f8279c860caa8c31534 to your computer and use it in GitHub Desktop.
Retro SuperMarioBros-Nes wrapper that skips busy frames and converts to discrete action space (alpha, not thoroughly tested)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Inspired by
#https://github.com/Kautenja/nes-py/blob/master/nes_py/wrappers/binary_to_discrete_space_env.py
#and
#https://github.com/openai/retro-baselines/blob/master/agents/sonic_util.py
#but copied directly from neither.
#Requires retro.
#Requires the SuperMarioBros NES ROM (USA/Japan). Once you have it, it can be imported into retro:
#> python -m retro.import directory_name
#(This identifies the ROM if it is in that directory and has the .nes extension.)
import gym | |
import retro | |
import numpy as np | |
from random import random,randrange | |
#Adds a step method that automatically skips cutscenes based on state hints.
#self.env.em.get_state() returns a byte string holding the state.
#Uses memory addresses and tricks from
#https://github.com/Kautenja/gym-super-mario-bros/blob/master/gym_super_mario_bros/smb_env.py
#together with self.env.em.set_state() and self.env.em.step().
#Is there a way to modify only part of the memory in-place?
#I think not, but skipping many frames at once can be valuable anyway.
class Retro_SMB_Wrapper(gym.ActionWrapper):
    """Discrete-action wrapper for Retro's SuperMarioBros-Nes with frame skipping.

    The wrapper does three things on top of the wrapped retro environment:

    * exposes a ``Discrete`` action space whose entries map to boolean
      button arrays (the set is meant to be equivalent to @Kautenja's
      COMPLEX_MOVEMENT);
    * repeats every chosen action for ``base_frameskip`` emulator frames and,
      with probability ``stickiness``, replaces the requested action by the
      previously executed one;
    * after each step, inspects the raw emulator state and fast-forwards
      through "busy" frames and cut scenes in which the player has no control.

    Very experimental (alpha). The skipping may not be perfect: two
    interactions with skipping are made for every death, not all skipping
    modes have been tested, and the in-game timer has not been given much
    attention.

    Args:
        env: The retro environment to wrap. It must expose ``env.em`` with
            ``get_state()``, ``set_state()``, ``step()`` and ``get_screen()``.
        base_frameskip: Number of emulator frames each action is held for.
            Must be at least 1.
        stickiness: Probability that a step reuses the previous action
            instead of the requested one.
        verbose: If true, print the expected frameskip at construction and a
            message whenever the emulator state is patched.

    Raises:
        ValueError: If ``base_frameskip`` is smaller than 1.
    """

    # Position of each button in the boolean action array given to the env.
    _BUTTON_INDICES = {"B": 0, None: 1, "select": 2, "start": 3, "up": 4,
                       "down": 5, "left": 6, "right": 7, "A": 8}

    # Button combinations, indexed by discrete action id.
    # Equivalent to @Kautenja's COMPLEX_MOVEMENT.
    _ACTIONS = (
        (None,),
        ("right",),
        ("right", "A"),
        ("right", "B"),
        ("right", "A", "B"),
        ("A",),
        ("left",),
        ("left", "A"),
        ("left", "B"),
        ("left", "A", "B"),
        ("down",),
        ("up",),
    )

    # Offset added to the RAM addresses used by @Kautenja to index into the
    # byte string returned by em.get_state(). Found empirically by the
    # original author ("I think this is the offset") -- TODO confirm.
    _STATE_OFFSET = 93

    # RAM addresses (before adding _STATE_OFFSET). The names mirror how the
    # values are used below; NOTE(review): verify them against the memory
    # map in Kautenja's gym_super_mario_bros/smb_env.py.
    _ADDR_PLAYER_STATE = 0x000E
    _ADDR_HOLE_INDICATOR = 0x00B5   # a value > 1 is treated as falling in a hole
    _ADDR_WORLD_STATUS = 0x0770     # a value of 2 is treated as "world done"
    _ADDR_CHANGE_AREA = 0x06DE      # a value > 1 is treated as changing area
    _ADDR_BUSY_TIMER = 0x07A0       # zeroed to shorten busy states / cut scenes
    _ADDR_TIME_ONES = 0x07FA        # ones digit of the in-game timer

    # Player-state values that get special treatment.
    _PLAYER_STATE_DYING = 0x0B
    _PLAYER_STATE_DEAD = 0x06

    def __init__(self, env, base_frameskip=4, stickiness=0.2, verbose=False):
        super().__init__(env)
        if base_frameskip < 1:
            # With zero frames, step() would have no observation to return.
            raise ValueError("base_frameskip must be at least 1")

        self.action_space = gym.spaces.Discrete(len(self._ACTIONS))

        # Precompute one boolean button array per discrete action.
        self._action_map = {}
        for action, button_list in enumerate(self._ACTIONS):
            bool_action = np.zeros(len(self._BUTTON_INDICES), dtype=bool)
            for button_name in button_list:
                bool_action[self._BUTTON_INDICES[button_name]] = True
            self._action_map[action] = bool_action

        self.base_frameskip = base_frameskip
        self.stickiness = stickiness
        self.previous_action = 0
        self.verbose = verbose

        if verbose:
            # Expected frameskip is a geometric sum. Guard the division so
            # that stickiness == 1 does not raise ZeroDivisionError.
            if stickiness < 1:
                expected_frameskip = base_frameskip / (1 - stickiness)
            else:
                expected_frameskip = float("inf")
            print("Expected frameskip: %.2f" % expected_frameskip)

    def reset(self, **kwargs):
        """Reset the wrapped environment and forget the previous action."""
        # A new episode must not inherit a sticky action from the last one.
        self.previous_action = 0
        return self.env.reset(**kwargs)

    def action(self, a):
        """Translate a discrete action id into its boolean button array."""
        return self._action_map[a]

    def step(self, a):
        """Hold action ``a`` for ``base_frameskip`` frames, then skip busy frames.

        Returns the usual ``(obs, reward, done, info)`` tuple, where the
        reward is the sum over the held frames. If the episode ends during
        the held frames, the tuple from that frame is returned immediately
        (with the reward accumulated so far) and no skipping is attempted.
        """
        # Apply action stickiness. This discourages path memorization and
        # (in some cases) encourages safe decisions, such as jumping some
        # time before Mario reaches the edge of a gap. Note that applying
        # stickiness here may be suboptimal: applying it in the policy
        # instead may save an otherwise wasted policy computation, and a
        # policy-gradient algorithm will not know that its action was not
        # carried out, so it may get a wrong gradient as a result.
        if random() < self.stickiness:
            a = self.previous_action
        # Remember what was actually executed; without this the "sticky"
        # action would always be action 0 (no buttons pressed).
        self.previous_action = a

        # Apply deterministic frameskip.
        rew = 0
        for _ in range(self.base_frameskip):
            obs, r, done, info = self.env.step(self._action_map[a])
            rew += r
            if done:
                return obs, rew, done, info

        # The skipping of busy states saves a lot of rendering time.
        self._skip_busy_frames()
        # Frames may have been advanced behind the env's back, so fetch the
        # current screen directly from the emulator.
        obs = self.env.em.get_screen()
        return obs, rew, done, info

    @staticmethod
    def _is_busy(player_state):
        """Return whether ``player_state`` marks frames without player control."""
        return player_state < 8 and player_state != 6

    def _skip_busy_frames(self):
        """Patch the emulator state and advance it past busy frames/cut scenes."""
        emulator = self.env.em
        offset = self._STATE_OFFSET
        player_state_index = self._ADDR_PLAYER_STATE + offset
        world_status_index = self._ADDR_WORLD_STATUS + offset
        change_area_index = self._ADDR_CHANGE_AREA + offset
        time_ones_index = self._ADDR_TIME_ONES + offset

        s = emulator.get_state()
        player_state = s[player_state_index]
        falling_in_hole = s[self._ADDR_HOLE_INDICATOR + offset] > 1
        busy = self._is_busy(player_state)
        world_done = s[world_status_index] == 2
        in_change_area = s[change_area_index] > 1

        # Collect the bytes to overwrite so the state is written back once.
        sparse_update = {}
        if busy or world_done:
            sparse_update[self._ADDR_BUSY_TIMER + offset] = 0
        if in_change_area:
            sparse_update[change_area_index] = 1
        if player_state == self._PLAYER_STATE_DYING or falling_in_hole:
            # Skip the death animation by jumping straight to the dead state.
            sparse_update[player_state_index] = self._PLAYER_STATE_DEAD

        if sparse_update:
            if self.verbose:
                print("Skipping busy frames.")
            patched = bytearray(s)
            for index, value in sparse_update.items():
                patched[index] = value
            emulator.set_state(bytes(patched))

        # NOTE(review): both loops below run until the emulator state
        # changes and have no frame cap; they never terminate if the
        # watched bytes stop changing.
        if world_done:
            # Advance until the ones digit of the timer changes.
            fixed_time_ones = s[time_ones_index]
            while True:
                emulator.step()
                s = emulator.get_state()
                if s[time_ones_index] != fixed_time_ones:
                    break

        if busy or world_done:
            # Advance at least one frame, then until control is regained.
            while True:
                emulator.step()
                s = emulator.get_state()
                busy = self._is_busy(s[player_state_index])
                world_done = s[world_status_index] == 2
                if not busy and not world_done:
                    break
#TEST | |
if __name__ == "__main__":
    # Smoke test: play two episodes with uniformly random actions.
    env = retro.make("SuperMarioBros-Nes")
    env = Retro_SMB_Wrapper(env)
    try:
        env.reset()
        games = 0
        while games < 2:
            # Sample from the wrapper's own action space rather than
            # hard-coding the number of discrete actions.
            obs, reward, done, info = env.step(randrange(env.action_space.n))
            env.render()
            if done:
                games += 1
                env.reset()
                print(games)
    finally:
        # Release the emulator even if stepping or rendering fails.
        env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment