# Created November 13, 2021 01:18.
# Save mhauskn/861e4983f54a435013f66e9ab44ea308 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently
# than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import argparse | |
#from tabulate import tabulate | |
import sys | |
from termcolor import colored | |
import jericho | |
import pickle | |
import random | |
from os.path import join as pjoin | |
from jericho.util import clean | |
from jericho.defines import BASIC_ACTIONS | |
import numpy as np | |
def get_detected_diffs(env):
    """Collect candidate actions for env's current state, grouped by world diff.

    Pipeline: find the interactive objects (using the object tree), score
    their names, generate candidate actions from the scored names, then run
    the candidate filter with ctypes and parallel mode enabled.

    Args:
        env: The jericho environment to probe, in the state of interest.

    Returns:
        The result of ``env._filter_candidate_actions`` — presumably a mapping
        from world diff to the candidate actions that produce it; TODO confirm
        against the installed jericho version.
    """
    found = env._identify_interactive_objects(use_object_tree=True)
    names = env._score_object_names(found)
    candidates = env.act_gen.generate_actions(names)
    return env._filter_candidate_actions(
        candidates,
        use_ctypes=True,
        use_parallel=True,
    )
def analyze_step(idx, env, gold_act):
    """Take one walkthrough step and report whether its world diff was anticipated.

    Checks that the diff generated by the walkthrough action (``gold_act``) is
    among the diffs considered by get_valid_actions, printing a coloured
    message when the action changes nothing in the world or when its diff was
    not detected.

    Args:
        idx: Index of ``gold_act`` in the walkthrough (used only in messages).
        env: The jericho environment being played; it is advanced by one step.
        gold_act: The walkthrough action to execute.

    Returns:
        The observation text returned by ``env.step(gold_act)``.
    """
    # Candidate diffs describe the state in which gold_act is taken, so they
    # must be collected before stepping.
    diff2acts = get_detected_diffs(env)
    obs, _, _, _ = env.step(gold_act)
    gold_diff = env._get_world_diff()
    # Examine ('x ...', 'examine ...') and 'z' actions are still stepped so the
    # walkthrough stays in sync, but they are excluded from the checks below.
    if gold_act.startswith(('x ', 'examine ')) or gold_act == 'z':
        return obs
    if not env._world_changed():
        print(colored('{}. NoWorldChange gold_act: {}, obs: {}'.format(idx, gold_act, clean(obs)), 'magenta'))
    if gold_diff not in diff2acts:
        print(colored('{}. gold_act: {}-{} not in valids. Obs: {}'.format(idx, gold_act, gold_diff, clean(obs)), 'red'))
    return obs


def run_walkthrough(rom=None):
    """Replay the ROM's walkthrough, checking every step.

    Two things are checked:
      1) Is the world diff of each walkthrough action among the diffs in
         get_valid_actions()? (delegated to analyze_step)
      2) Has the world-state hash of the current state been encountered
         before, or is it unique?

    Args:
        rom: Path to the game ROM. Defaults to the module-level ``ROM`` (set in
            the ``__main__`` block), so zero-argument calls keep working.
    """
    if rom is None:
        rom = ROM
    env = jericho.FrotzEnv(rom)
    obs, _ = env.reset()
    walkthrough = env.get_walkthrough()
    # world-state hash -> walkthrough indices at which that state was seen
    hashes = {}
    for idx, gold_act in enumerate(walkthrough):
        whash = env.get_world_state_hash()
        seen_at = hashes.setdefault(whash, [])
        if seen_at:
            # obs is the text that led into this state (the reset text for
            # idx 0); it is refreshed from analyze_step on every iteration.
            print(colored('{}: {} States with same hash: {}'.format(idx, clean(obs), seen_at), 'cyan'))
        seen_at.append(idx)
        print(f'Step {idx} - {gold_act}')
        obs = analyze_step(idx, env, gold_act)
def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Expects exactly one positional argument, ``rom``: the path to the game
    ROM file whose walkthrough will be replayed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('rom', help='Rom File', type=str)
    return cli.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # This assignment is already at module scope, so it binds the global ROM
    # that run_walkthrough() reads; a `global` statement here would be a no-op.
    ROM = args.rom
    run_walkthrough()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.