# Source: GitHub Gist by @christopherhesse
# (gist id f5c8cda9ab40e62cdddbc31ff9802594), created February 4, 2019 04:31.
import retro
import numpy as np
import gym.spaces
import argparse
import multiprocessing as mp
# Number of actions per replay segment: check_env_helper() snapshots the
# emulator state after every CHUNK_LENGTH actions, and main() samples
# CHUNK_LENGTH * 2 random actions per game.
CHUNK_LENGTH = 128
def restore_state(env, state):
    """Load a previously captured emulator state into ``env``.

    Args:
        env: environment exposing ``em`` (with ``set_state``) and ``data``
            (with ``reset`` and ``update_ram``).
        state: state blob, as returned by ``env.em.get_state()``.
    """
    env.em.set_state(state)
    # After swapping the emulator state, reset the data object and have it
    # re-read RAM (presumably so cached game variables match the new state).
    game_data = env.data
    game_data.reset()
    game_data.update_ram()
def rollout(env, acts):
    """Step ``env`` through ``acts`` and return the summed reward.

    Stepping stops early, right after the step that reports ``done``; that
    step's reward is still included in the total.

    Args:
        env: environment whose ``step(action)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        acts: iterable of actions to apply in order.

    Returns:
        The accumulated reward, starting from ``0.0``.
    """
    reward_sum = 0.0
    for action in acts:
        _, reward, finished, _ = env.step(action)
        reward_sum += reward
        if finished:
            return reward_sum
    return reward_sum
def chunk(L, length):
    """Split ``L`` into consecutive slices of at most ``length`` items.

    Every slice except possibly the last has exactly ``length`` items; an
    empty ``L`` yields an empty list.

    Args:
        L: a sliceable sequence (list, tuple, str, ...).
        length: maximum number of items per slice; must be >= 1.

    Returns:
        A list of slices of ``L``, in order.

    Raises:
        ValueError: if ``length`` is less than 1. A non-positive length
            cannot produce a meaningful split (a negative one would silently
            drop trailing items).
    """
    if length < 1:
        raise ValueError('length must be a positive integer')
    # Step through start offsets instead of repeatedly re-slicing the
    # remainder, which copied the tail on every iteration.
    return [L[start:start + length] for start in range(0, len(L), length)]
def partition(L, pieces):
    """Split ``L`` into consecutive sublists using ``chunk``.

    The sublist size is ``len(L) // pieces + 1``, so the result never has
    more than ``pieces`` sublists but may have fewer (for example when
    ``len(L)`` is an exact multiple of ``pieces``).

    Args:
        L: sequence to split.
        pieces: upper bound on the number of sublists; must be non-zero.

    Returns:
        A list of consecutive slices of ``L``.
    """
    piece_size = len(L) // pieces + 1
    return chunk(L, piece_size)
def check_env_helper(make_env, acts, verbose, out_success):
    """Check that restoring saved emulator states reproduces a reference run.

    A reference pass plays ``acts`` in segments of ``CHUNK_LENGTH`` actions,
    recording the emulator state before each segment and the reward total
    and RAM after it. Then, for every recorded state, the state is restored
    and all remaining segments are replayed; each replayed reward total and
    RAM dump must equal the reference values.

    Args:
        make_env: zero-argument callable returning a fresh environment.
        acts: sequence of actions to replay.
        verbose: if true, print ``<index> <total>`` progress for each
            restored state.
        out_success: object whose ``value`` attribute receives the boolean
            result. It is only assigned if the check runs to completion.
    """
    env = make_env()
    try:
        # do rollouts and get reference values
        env.reset()
        in_states = [env.em.get_state()]
        in_acts = chunk(acts, CHUNK_LENGTH)
        out_rews = []
        out_rams = []
        for segment in in_acts:
            out_rews.append(rollout(env, segment))
            out_rams.append(env.get_ram())
            in_states.append(env.em.get_state())
        in_states.pop()  # remove extra final state since there are no actions after it

        success = True
        for start_idx, state in enumerate(in_states):
            if verbose:
                print(start_idx + 1, len(in_states))
            restore_state(env, state)
            # Replay every segment from this snapshot onward and compare it
            # against the reference pass.
            for offset, segment in enumerate(in_acts[start_idx:]):
                ref_idx = start_idx + offset
                if not np.array_equal(rollout(env, segment), out_rews[ref_idx]):
                    print('failed rew')
                    success = False
                if not np.array_equal(env.get_ram(), out_rams[ref_idx]):
                    print('failed ram')
                    success = False
    finally:
        # Always release the emulator, even if a step or comparison raised.
        env.close()
    out_success.value = success
def check_env(make_env, acts, verbose=False, timeout=None):
    """Run ``check_env_helper`` in a child process and report its verdict.

    Args:
        make_env: zero-argument callable returning a fresh environment;
            forwarded to the child process.
        acts: sequence of actions to replay.
        verbose: forwarded to ``check_env_helper``.
        timeout: seconds to wait for the child, or ``None`` to wait
            indefinitely.

    Returns:
        ``True`` only if the child finished in time and set the shared flag
        to true. A timeout (the child is then terminated) or a child that
        exits without setting the flag yields ``False``.
    """
    # Shared flag starts out False, so a child that dies early reads as failure.
    result_flag = mp.Value('b', False)
    worker = mp.Process(
        target=check_env_helper,
        args=(make_env, acts, verbose, result_flag),
        daemon=True,
    )
    worker.start()
    worker.join(timeout)
    if not worker.is_alive():
        return bool(result_flag.value)
    print('failed to finish in time')
    worker.terminate()
    worker.join()
    return False
def main():
    """Command-line entry point for the determinism check.

    Without ``--movie-file``: every game whose name ends with ``--suffix``
    (default: all games) is checked with ``CHUNK_LENGTH * 2`` randomly
    sampled actions and a 16 second timeout; failing games are listed at
    the end.

    With ``--movie-file``: the actions recorded in the bk2 movie are
    extracted and checked against the movie's game, starting from the
    movie's initial state, with verbose progress output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--suffix', default='', help='run against games matching this suffix')
    parser.add_argument('--movie-file', help='load a bk2 and use states obtained from replaying actions from the bk2')
    args = parser.parse_args()

    # ``--suffix`` always has a default, so the two branches below are
    # exhaustive; no "must specify ..." error branch can ever be reached.
    if args.movie_file is None:
        games = [g for g in sorted(retro.data.list_games()) if g.endswith(args.suffix)]
        failed_games = []
        for game in games:
            print(game)

            def make_env():
                return retro.make(game=game)

            gym.spaces.seed(0)
            # A temporary env is needed only to sample actions; close it
            # before the check creates its own env in the child process.
            env = make_env()
            try:
                acts = [env.action_space.sample() for _ in range(CHUNK_LENGTH * 2)]
            finally:
                env.close()
            if not check_env(make_env, acts, timeout=16):
                failed_games.append(game)
        for game in failed_games:
            print('failed:', game)
    else:
        movie = retro.Movie(args.movie_file)
        movie.step()

        def make_env():
            env = retro.make(movie.get_game(), state=retro.State.DEFAULT, use_restricted_actions=retro.Actions.ALL)
            env.initial_state = movie.get_state()
            return env

        # The button count comes from an environment instance, so create one
        # briefly to read it (previously ``env`` was referenced here without
        # ever being defined).
        env = make_env()
        try:
            num_buttons = env.num_buttons
        finally:
            env.close()

        acts = []
        while movie.step():
            acts.append([
                movie.get_key(i, p)
                for p in range(movie.players)
                for i in range(num_buttons)
            ])
        # check_env expects the environment factory, not an instance.
        check_env(make_env, acts, verbose=True)
# Script entry point: run the checks only when this file is executed directly.
if __name__ == '__main__':
    main()
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment