Skip to content

Instantly share code, notes, and snippets.

@christopherhesse
Created February 4, 2019 23:37
Show Gist options
  • Save christopherhesse/c55c1b2e130eea06e45080d092e78951 to your computer and use it in GitHub Desktop.
Save christopherhesse/c55c1b2e130eea06e45080d092e78951 to your computer and use it in GitHub Desktop.
import retro
import retro.retro_env
import numpy as np
import gym.spaces
import gym.envs.atari
import argparse
import multiprocessing as mp
from fnmatch import fnmatch
# Number of actions replayed between consecutive save-state snapshots.
CHUNK_LENGTH = 128

# Atari (ALE) game ids checked in addition to the games reported by retro.
# Kept in alphabetical order, one initial letter per line.
ALE_GAMES = [
    'air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis',
    'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout',
    'carnival', 'centipede', 'chopper_command', 'crazy_climber',
    'demon_attack', 'double_dunk',
    'elevator_action', 'enduro',
    'fishing_derby', 'freeway', 'frostbite',
    'gopher', 'gravitar',
    'hero',
    'ice_hockey',
    'jamesbond', 'journey_escape',
    'kangaroo', 'krull', 'kung_fu_master',
    'montezuma_revenge', 'ms_pacman',
    'name_this_game',
    'phoenix', 'pitfall', 'pong', 'pooyan', 'private_eye',
    'qbert',
    'riverraid', 'road_runner', 'robotank',
    'seaquest', 'skiing', 'solaris', 'space_invaders', 'star_gunner',
    'tennis', 'time_pilot', 'tutankham',
    'up_n_down',
    'venture', 'video_pinball',
    'wizard_of_wor',
    'yars_revenge',
    'zaxxon',
]
class RetroEnv(retro.retro_env.RetroEnv):
    """Retro environment with explicit save-state accessors.

    Adds ``get_state`` / ``set_state`` on top of the stock retro environment
    so emulator snapshots can be taken and restored by the checker.
    """

    def get_state(self):
        """Return the emulator's current state as reported by ``em.get_state()``."""
        emulator = self.unwrapped.em
        return emulator.get_state()

    def set_state(self, state):
        """Load ``state`` into the emulator and refresh the game-data view.

        After the emulator state is replaced, the data object is reset and
        its RAM view updated (presumably so derived variables reflect the
        restored memory -- confirm against the retro data API).
        """
        core = self.unwrapped
        core.em.set_state(state)
        core.data.reset()
        core.data.update_ram()
class ALEEnv(gym.envs.atari.AtariEnv):
    """Atari environment exposing the same RAM/state accessors as ``RetroEnv``."""

    def get_ram(self):
        """Return the RAM contents via the underlying env's ``_get_ram()``."""
        base_env = self.unwrapped
        return base_env._get_ram()

    def get_state(self):
        """Return a snapshot produced by ``clone_full_state()``."""
        base_env = self.unwrapped
        return base_env.clone_full_state()

    def set_state(self, state):
        """Restore a snapshot previously returned by ``get_state``."""
        base_env = self.unwrapped
        base_env.restore_full_state(state)
def rollout(env, acts):
    """Step ``env`` through ``acts`` and return the summed reward.

    Stepping stops early, right after the step that reports ``done``; the
    reward of that final step is still included in the total.

    Args:
        env: object whose ``step(action)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        acts: iterable of actions to apply in order.

    Returns:
        The accumulated reward as a float (``0.0`` when ``acts`` is empty).
    """
    reward_sum = 0.0
    for action in acts:
        _, reward, finished, _ = env.step(action)
        reward_sum += reward
        if finished:
            # Episode over: remaining actions are intentionally not applied.
            return reward_sum
    return reward_sum
def chunk(L, length):
    """Split ``L`` into consecutive slices of at most ``length`` items.

    Every slice except possibly the last has exactly ``length`` items, and
    concatenating the slices reproduces ``L``.

    Args:
        L: a sliceable sequence supporting ``len()`` (list, tuple, str, ...).
        length: maximum size of each slice; must be a positive integer.

    Returns:
        A list of slices of ``L`` (an empty list when ``L`` is empty).

    Raises:
        ValueError: if ``length`` is not positive. A zero or negative length
            previously produced an empty or nonsensical result without
            any error.
    """
    if length <= 0:
        raise ValueError('length must be a positive integer, got %r' % (length,))
    # Slice by stride instead of repeatedly re-slicing the remainder, which
    # copied the tail of the sequence on every iteration.
    return [L[start:start + length] for start in range(0, len(L), length)]
def partition(L, pieces):
    """Split ``L`` into at most ``pieces`` consecutive, near-equal slices.

    The slice size is ``ceil(len(L) / pieces)``. The previous formula,
    ``len(L) // pieces + 1``, oversized the slices whenever ``len(L)`` was an
    exact multiple of ``pieces`` and therefore returned fewer pieces than
    requested (e.g. 6 items in 3 pieces came back as 2 slices of 3).

    Args:
        L: a sliceable sequence supporting ``len()``.
        pieces: desired number of slices; must be a positive integer.

    Returns:
        A list of slices of ``L`` (empty when ``L`` is empty).

    Raises:
        ValueError: if ``pieces`` is not positive.
    """
    if pieces <= 0:
        raise ValueError('pieces must be a positive integer, got %r' % (pieces,))
    # Ceiling division without floats; clamp to 1 so an empty ``L`` still
    # hands ``chunk`` a valid slice size.
    size = max(1, -(-len(L) // pieces))
    return chunk(L, size)
def check_env_helper(make_env, acts, verbose, out_success):
    """Check that restoring saved states reproduces the original rollout.

    The action list is cut into chunks of ``CHUNK_LENGTH``. A reference pass
    plays all chunks in order, recording the state before each chunk and the
    reward and RAM after it. Then, for every recorded state, the state is
    restored and all remaining chunks are replayed; each chunk's reward and
    resulting RAM must equal the reference values.

    Args:
        make_env: zero-argument callable returning an env that provides
            ``reset``, ``step``, ``get_state``, ``set_state``, ``get_ram``
            and ``close``.
        acts: sequence of actions to replay.
        verbose: when true, print progress as ``<index> <total>`` per start
            state.
        out_success: object with a writable ``value`` attribute; set to the
            overall result. It is only written when the check runs to
            completion, so it keeps its initial value if an exception
            escapes.
    """
    env = make_env()
    try:
        # Reference pass: do rollouts and record the expected values.
        env.reset()
        in_states = [env.get_state()]
        in_acts = chunk(acts, CHUNK_LENGTH)
        out_rews = []
        out_rams = []
        for chunk_acts in in_acts:
            out_rews.append(rollout(env, chunk_acts))
            out_rams.append(env.get_ram())
            in_states.append(env.get_state())
        # Remove the extra final state since there are no actions after it.
        in_states.pop()

        success = True
        for start_idx, state in enumerate(in_states):
            if verbose:
                print(start_idx + 1, len(in_states))
            env.set_state(state)
            # Replay every chunk from this snapshot to the end, not just the
            # next one, so divergence that only shows up later is caught.
            remaining = zip(in_acts[start_idx:], out_rews[start_idx:], out_rams[start_idx:])
            for chunk_acts, expected_rew, expected_ram in remaining:
                if not np.array_equal(rollout(env, chunk_acts), expected_rew):
                    print('failed rew')
                    success = False
                if not np.array_equal(env.get_ram(), expected_ram):
                    print('failed ram')
                    success = False
    finally:
        # Always release the emulator, even if a step or restore raised.
        env.close()
    out_success.value = success
def check_env(make_env, acts, verbose=False, timeout=None):
    """Run ``check_env_helper`` in a child process and return its verdict.

    Args:
        make_env: zero-argument env factory forwarded to the helper.
        acts: sequence of actions forwarded to the helper.
        verbose: forwarded to the helper.
        timeout: seconds to wait for the child; ``None`` waits indefinitely.

    Returns:
        ``True`` only if the child finished in time and reported success.
        ``False`` if it timed out (the child is terminated) or if it never
        set the shared flag.
    """
    # Shared flag starts out False, so a child that never writes it counts
    # as a failure.
    result_flag = mp.Value('b', False)
    worker = mp.Process(
        target=check_env_helper,
        args=(make_env, acts, verbose, result_flag),
        daemon=True,
    )
    worker.start()
    worker.join(timeout)

    if not worker.is_alive():
        return bool(result_flag.value)

    print('failed to finish in time')
    worker.terminate()
    worker.join()
    return False
def main():
    """Command-line entry point for the save-state consistency checker.

    Two modes are supported:

    * ``--movie-file FILE``: replay the button presses recorded in a bk2
      movie and check the game it was recorded on.
    * ``--pattern GLOB``: check every retro game and every ALE game in
      ``ALE_GAMES`` whose name matches the glob, using seeded random
      actions, and list the games that failed.

    ``--movie-file`` takes precedence when both are given. If neither is
    given, the parser reports a usage error and exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pattern', default='', help='run against games matching this pattern')
    parser.add_argument('--movie-file', help='load a bk2 and use states obtained from replaying actions from the bk2')
    args = parser.parse_args()

    if args.movie_file is not None:
        movie = retro.Movie(args.movie_file)
        movie.step()

        def make_env():
            # ``players`` must match the movie, because each action built
            # below holds num_buttons entries for every recorded player.
            env = retro.make(
                movie.get_game(),
                state=retro.State.DEFAULT,
                use_restricted_actions=retro.Actions.ALL,
                players=movie.players,
            )
            env.initial_state = movie.get_state()
            return env

        # A short-lived env is needed only to learn the button count. It is
        # closed before the check runs, mirroring the pattern branch below.
        env = make_env()
        try:
            num_buttons = env.num_buttons
        finally:
            env.close()

        acts = []
        while movie.step():
            acts.append([
                movie.get_key(i, p)
                for p in range(movie.players)
                for i in range(num_buttons)
            ])
        # Pass the factory, not an env instance: the helper calls it to
        # build its own env inside the child process.
        if not check_env(make_env, acts, verbose=True):
            print('failed:', args.movie_file)
    elif args.pattern:
        pattern = args.pattern.lower()
        games = [g for g in sorted(retro.data.list_games()) if fnmatch(g.lower(), pattern)]
        games += ['ale:' + g for g in ALE_GAMES if fnmatch(g, pattern)]
        failed_games = []
        for game in games:
            print(game)

            def make_env(game=game):
                if game.startswith('ale:'):
                    return ALEEnv(game.split(':')[1], obs_type='image', frameskip=1)
                else:
                    return RetroEnv(game=game)

            # Seed so every game is exercised with a reproducible action list.
            gym.spaces.seed(0)
            env = make_env()
            try:
                acts = [env.action_space.sample() for _ in range(CHUNK_LENGTH * 16)]
            finally:
                env.close()
            if not check_env(make_env, acts, timeout=16, verbose=len(acts) > 1000):
                failed_games.append(game)
        for game in failed_games:
            print('failed:', game)
    else:
        # Neither mode selected: report it instead of silently doing nothing.
        parser.error('must specify --pattern or --movie-file')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment