Skip to content

Instantly share code, notes, and snippets.

@EliasHasle
Last active April 14, 2019 14:53
Show Gist options
  • Save EliasHasle/5958fceccbb10f8279c860caa8c31534 to your computer and use it in GitHub Desktop.
Save EliasHasle/5958fceccbb10f8279c860caa8c31534 to your computer and use it in GitHub Desktop.
Retro SuperMarioBros-Nes wrapper that skips busy frames and converts to discrete action space (alpha, not thoroughly tested)
#Inspired by
#https://github.com/Kautenja/nes-py/blob/master/nes_py/wrappers/binary_to_discrete_space_env.py
#and
#https://github.com/openai/retro-baselines/blob/master/agents/sonic_util.py
#but copying directly from neither.
#Requires retro.
#Requires the SuperMarioBros NES ROM (USA/Japan). Once you have it, it can be imported into retro:
#> python -m retro.import directory_name
#(identifies the ROM if it is there and has the .nes extension.)
import gym
import retro
import numpy as np
from random import random,randrange
#Add step method with automatic skipping of cutscenes based on state hints.
#self.env.em.get_state() returns a byte string holding the state.
#Use memory addresses and tricks from
#https://github.com/Kautenja/gym-super-mario-bros/blob/master/gym_super_mario_bros/smb_env.py
#self.env.em.set_state(), self.env.em.step()
#Is there a way to modify only part of the memory in-place?
#I think not, but skipping many frames at once can be valuable anyway.
class Retro_SMB_Wrapper(gym.ActionWrapper):
    """Discrete-action wrapper for retro's SuperMarioBros-Nes that skips busy frames.

    The wrapper does three things:

    * exposes a ``gym.spaces.Discrete`` action space whose entries map to
      boolean button arrays (equivalent to @Kautenja's COMPLEX_MOVEMENT);
    * applies "sticky" actions plus a deterministic frameskip in ``step``;
    * after each step, inspects the serialized emulator state and, when
      Mario is in a busy state or a cut scene, patches the state and
      advances the emulator until control returns to the player.

    Very experimental, but seems to work quite well. The skipping of busy
    states saves a lot of rendering time. The skipping may not be perfect
    yet, because two interactions with skipping are made for every death.
    Also, not all skipping modes are tested yet, and the timer has not been
    paid much attention to.

    Args:
        env: the retro environment to wrap. ``step`` reads ``env.em``, so
            it must expose the emulator under that attribute.
        base_frameskip: number of emulator frames each action is held for.
            Must be at least 1.
        stickiness: probability that the previously executed action is
            repeated instead of the requested one.
        verbose: when true, print the expected frameskip at construction
            and a message whenever busy frames are skipped.

    Raises:
        ValueError: if ``base_frameskip`` is smaller than 1.
    """

    def __init__(self, env, base_frameskip=4, stickiness=0.2, verbose=False):
        super(Retro_SMB_Wrapper, self).__init__(env)
        if base_frameskip < 1:
            # With no frame stepped, step() would have nothing to return.
            raise ValueError("base_frameskip must be at least 1")

        # Position of each button in the boolean action array handed to the
        # wrapped env. NOTE(review): presumably retro's NES button order
        # (slot 1 is unused there) -- confirm against env.buttons.
        button_indices = {"B": 0, None: 1, "select": 2, "start": 3, "up": 4,
                          "down": 5, "left": 6, "right": 7, "A": 8}
        #Equivalent to @Kautenja's COMPLEX_MOVEMENT:
        actions = [
            [None],
            ["right"],
            ["right", "A"],
            ["right", "B"],
            ["right", "A", "B"],
            ["A"],
            ["left"],
            ["left", "A"],
            ["left", "B"],
            ["left", "A", "B"],
            ["down"],
            ["up"],
        ]
        self.action_space = gym.spaces.Discrete(len(actions))

        # Discrete action index -> boolean button array.
        self._action_map = {}
        for action, button_list in enumerate(actions):
            bool_action = np.zeros(len(button_indices), dtype=bool)
            for button_name in button_list:
                bool_action[button_indices[button_name]] = True
            self._action_map[action] = bool_action

        self.base_frameskip = base_frameskip
        self.stickiness = stickiness
        # Last action actually executed; index 0 is the no-op action.
        self.previous_action = 0
        self.verbose = verbose

        if verbose:
            ## Expected frameskip is a geometric sum:
            if stickiness < 1:
                expected_frameskip = base_frameskip/(1-stickiness)
            else:
                # Always sticky: the same action is repeated forever.
                expected_frameskip = float("inf")
            print("Expected frameskip: %.2f" % expected_frameskip)

    @staticmethod
    def _is_busy(player_state):
        """Return True for player-state values treated as "busy" (0-7 except 6)."""
        return player_state < 8 and player_state != 6

    def reset(self):
        """Reset the wrapped env and forget the previous (sticky) action."""
        # A new episode must not inherit the last action of the old one.
        self.previous_action = 0
        return self.env.reset()

    def action(self, a):
        """Translate discrete action index ``a`` into its boolean button array."""
        return self._action_map[a]

    def step(self, a):
        """Run action ``a`` for ``base_frameskip`` frames, then skip busy frames.

        Args:
            a: index into the discrete action space.

        Returns:
            ``(obs, rew, done, info)`` where ``rew`` is the reward summed
            over the frameskip. If busy frames were skipped, ``obs`` is the
            emulator screen after skipping; ``rew``, ``done`` and ``info``
            are not updated by the skipped frames.
        """
        #Apply action stickiness. This discourages path memorization
        #and (in some cases) encourages safe decisions, such as
        #jumping some time before Mario reaches the edge of a gap.
        #Note that applying stickiness here may be
        #suboptimal. Applying it in the policy instead may save
        #an otherwise wasted policy computation. Also, when doing it
        #here, the PG algorithm will not know that its action is not
        #carried out, so may have a wrong gradient as a result.
        if random() < self.stickiness:
            a = self.previous_action
        # Remember what was actually executed so the next sticky step
        # repeats it.
        self.previous_action = a

        #Apply deterministic frameskip:
        buttons = self._action_map[a]
        rew = 0
        for _ in range(self.base_frameskip):
            obs, r, done, info = self.env.step(buttons)
            rew += r
            # Stop stepping as soon as the episode ends.
            if done:
                return obs, rew, done, info

        emulator = self.env.em
        s = emulator.get_state()
        #I think this is the offset wrt. the addresses used by @Kautenja
        offset = 93
        player_state_index = 0x000e+offset
        time_ones_index = 0x07FA+offset
        world_index = 0x0770+offset
        change_area_index = 0x06DE+offset

        # NOTE(review): the meaning of each address follows the variable
        # names and @Kautenja's smb_env; verify against that file.
        player_state = s[player_state_index]
        falling_in_hole = s[0x00b5+offset] > 1
        busy = self._is_busy(player_state)
        world_done = s[world_index] == 2
        in_change_area = s[change_area_index] > 1

        #Skip busy states and cut scenes:
        # Byte index -> new value to patch into the serialized state.
        sparse_update = {}
        if busy or world_done:
            sparse_update[0x07A0+offset] = 0
        if in_change_area:
            sparse_update[change_area_index] = 1
        if player_state == 0x0B or falling_in_hole:
            sparse_update[player_state_index] = 0x06

        if sparse_update:
            if self.verbose:
                print("Skipping busy frames.")
            # The state is an immutable byte string, so patch a mutable
            # copy and write the whole thing back.
            patched = bytearray(s)
            for key, value in sparse_update.items():
                patched[key] = value
            emulator.set_state(bytes(patched))

            if world_done:
                # Advance until the byte at 0x07FA changes (presumably the
                # ones digit of the in-game timer -- TODO confirm).
                fixed_time_ones = s[time_ones_index]
                while True:
                    emulator.step()
                    s = emulator.get_state()
                    if s[time_ones_index] != fixed_time_ones:
                        break

            if busy or world_done:
                # Advance until the player is controllable again.
                while True:
                    emulator.step()
                    s = emulator.get_state()
                    busy = self._is_busy(s[player_state_index])
                    world_done = s[world_index] == 2
                    if not busy and not world_done:
                        break

            # The observation from env.step is stale after manual stepping.
            obs = emulator.get_screen() #shape matches

        return obs, rew, done, info
#TEST
if __name__ == "__main__":
    # Smoke test: play two episodes with uniformly random actions while
    # rendering, printing the number of finished games after each one.
    env = retro.make("SuperMarioBros-Nes")
    try:
        env = Retro_SMB_Wrapper(env)
        env.reset()
        games = 0
        while games < 2:
            # Sample from the wrapper's own action space rather than a
            # hard-coded count, so the test follows the action table.
            obs, reward, done, info = env.step(randrange(env.action_space.n))
            env.render()
            if done:
                games += 1
                env.reset()
                print(games)
    finally:
        # Release the emulator even if stepping or rendering fails.
        env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment