-
-
Save christopherhesse/f5c8cda9ab40e62cdddbc31ff9802594 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import retro | |
import numpy as np | |
import gym.spaces | |
import argparse | |
import multiprocessing as mp | |
# Number of actions replayed per rollout segment. check_env_helper splits the
# action list into pieces of this size and snapshots the emulator state at
# each piece boundary; main() samples CHUNK_LENGTH * 2 random actions per game.
CHUNK_LENGTH = 128
def restore_state(env, state):
    """Load a saved emulator state into ``env`` and resynchronise its data.

    Args:
        env: Environment exposing ``em`` (with ``set_state``) and ``data``
            (with ``reset`` and ``update_ram``).
        state: A state object previously obtained from ``env.em.get_state()``.
    """
    emulator = env.em
    emulator.set_state(state)

    # The data object is reset and then told to re-read RAM so it reflects the
    # state that was just loaded rather than whatever was there before.
    game_data = env.data
    game_data.reset()
    env.data.update_ram()
def rollout(env, acts):
    """Step ``env`` through ``acts`` and return the summed reward.

    Stepping stops early, after adding that step's reward, as soon as the
    environment reports ``done``.

    Args:
        env: Environment whose ``step(act)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        acts: Iterable of actions to apply in order.

    Returns:
        The accumulated reward, starting from ``0.0``.
    """
    reward_sum = 0.0
    for action in acts:
        _, reward, finished, _ = env.step(action)
        reward_sum += reward
        if finished:
            return reward_sum
    return reward_sum
def chunk(L, length):
    """Split a sliceable sequence into consecutive pieces of at most ``length`` items.

    Args:
        L: Sequence supporting ``len()`` and slicing (e.g. a list).
        length: Maximum number of items per piece; must be positive.

    Returns:
        A list of slices of ``L`` in order. Every piece holds ``length`` items
        except possibly the last, which holds the remainder. An empty ``L``
        yields an empty list.

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        # A zero length would return [] and silently drop every item, and a
        # negative one would produce meaningless slices; fail loudly instead.
        raise ValueError('length must be positive, got {!r}'.format(length))
    # Slice by offset instead of repeatedly re-slicing the remainder, which
    # would copy the shrinking tail on every iteration.
    return [L[start:start + length] for start in range(0, len(L), length)]
def partition(L, pieces):
    """Split ``L`` into consecutive sublists using ``chunk``.

    The piece length is ``len(L) // pieces + 1``, so for a positive ``pieces``
    the result contains at most ``pieces`` sublists (possibly fewer, and the
    last one may be shorter than the others).

    Args:
        L: Sequence supporting ``len()`` and slicing.
        pieces: Upper bound on the number of sublists wanted.

    Returns:
        The list of sublists produced by ``chunk``.
    """
    piece_length = len(L) // pieces + 1
    return chunk(L, piece_length)
def check_env_helper(make_env, acts, verbose, out_success):
    """Check that replaying actions from saved states reproduces rewards and RAM.

    A reference pass plays ``acts`` in segments of ``CHUNK_LENGTH`` actions,
    recording the emulator state before each segment and the reward and RAM
    after it. Then, for every recorded state, the state is restored and all
    remaining segments are replayed; any reward or RAM mismatch against the
    reference marks the check as failed.

    Args:
        make_env: Zero-argument callable returning a fresh environment.
        acts: Full list of actions to replay.
        verbose: When true, print progress as ``<index> <total>`` per start state.
        out_success: Object whose ``value`` attribute receives the boolean result.
    """
    env = make_env()
    env.reset()

    # Reference pass: collect the expected reward and RAM for each segment.
    start_states = [env.em.get_state()]
    segments = chunk(acts, CHUNK_LENGTH)
    ref_rews = []
    ref_rams = []
    for segment in segments:
        ref_rews.append(rollout(env, segment))
        ref_rams.append(env.get_ram())
        start_states.append(env.em.get_state())
    # The state captured after the last segment has no actions following it.
    start_states.pop()

    # Verification pass: restart from every snapshot and replay to the end.
    success = True
    state_count = len(start_states)
    segment_count = len(segments)
    for position, saved_state in enumerate(start_states, start=1):
        if verbose:
            print(position, state_count)
        restore_state(env, saved_state)
        for idx in range(position - 1, segment_count):
            replayed_rew = rollout(env, segments[idx])
            if not np.array_equal(replayed_rew, ref_rews[idx]):
                print('failed rew')
                success = False
            replayed_ram = env.get_ram()
            if not np.array_equal(replayed_ram, ref_rams[idx]):
                print('failed ram')
                success = False

    env.close()
    out_success.value = success
def check_env(make_env, acts, verbose=False, timeout=None):
    """Run ``check_env_helper`` in a child process and report its result.

    Args:
        make_env: Zero-argument callable returning a fresh environment.
        acts: List of actions to replay.
        verbose: Forwarded to ``check_env_helper``.
        timeout: Seconds to wait for the child; ``None`` waits indefinitely.

    Returns:
        ``True`` if the helper finished and reported success; ``False`` if it
        reported failure, never set a result, or was terminated after exceeding
        ``timeout``.
    """
    # Shared flag the child writes its verdict into; it stays False if the
    # child exits without setting it.
    result = mp.Value('b', False)
    worker = mp.Process(
        target=check_env_helper,
        args=(make_env, acts, verbose, result),
        daemon=True,
    )
    worker.start()
    worker.join(timeout)

    if not worker.is_alive():
        return bool(result.value)

    print('failed to finish in time')
    worker.terminate()
    worker.join()
    return False
def main():
    """Command-line entry point for the replay consistency check.

    Without ``--movie-file``, every game whose name ends with ``--suffix`` is
    checked using ``CHUNK_LENGTH * 2`` randomly sampled actions, and the names
    of failing games are printed at the end. With ``--movie-file``, the actions
    recorded in the given bk2 movie are replayed against the movie's game and
    a failure line is printed if the check does not pass.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--suffix', default='', help='run against games matching this suffix')
    parser.add_argument('--movie-file', help='load a bk2 and use states obtained from replaying actions from the bk2')
    args = parser.parse_args()

    if args.movie_file is None:
        games = [g for g in sorted(retro.data.list_games()) if g.endswith(args.suffix)]
        failed_games = []
        for game in games:
            print(game)

            def make_env():
                return retro.make(game=game)

            # Seed before sampling so each game gets a repeatable action list.
            gym.spaces.seed(0)
            env = make_env()
            acts = [env.action_space.sample() for _ in range(CHUNK_LENGTH * 2)]
            env.close()
            if not check_env(make_env, acts, timeout=16):
                failed_games.append(game)
        for game in failed_games:
            print('failed:', game)
    else:
        movie = retro.Movie(args.movie_file)
        movie.step()

        def make_env():
            env = retro.make(movie.get_game(), state=retro.State.DEFAULT, use_restricted_actions=retro.Actions.ALL)
            env.initial_state = movie.get_state()
            return env

        # An environment is needed only to learn how many buttons to read per
        # player. Previously `env` was never assigned in this branch, so the
        # loop below raised UnboundLocalError. Close it before running the
        # check, mirroring the per-game branch above.
        env = make_env()
        num_buttons = env.num_buttons
        env.close()

        acts = []
        while movie.step():
            act = []
            for p in range(movie.players):
                for i in range(num_buttons):
                    act.append(movie.get_key(i, p))
            acts.append(act)

        # check_env expects the environment factory, not an environment.
        if not check_env(make_env, acts, verbose=True):
            print('failed:', args.movie_file)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment