Skip to content

Instantly share code, notes, and snippets.

@mhauskn
Created November 13, 2021 01:18
Show Gist options
  • Save mhauskn/861e4983f54a435013f66e9ab44ea308 to your computer and use it in GitHub Desktop.
Save mhauskn/861e4983f54a435013f66e9ab44ea308 to your computer and use it in GitHub Desktop.
import argparse
#from tabulate import tabulate
import sys
from termcolor import colored
import jericho
import pickle
import random
from os.path import join as pjoin
from jericho.util import clean
from jericho.defines import BASIC_ACTIONS
import numpy as np
def get_detected_diffs(env):
    """
    Run jericho's candidate-action pipeline on env's current state.

    The steps are:
      1. identify interactive objects (using the object tree);
      2. score/select object names;
      3. generate candidate actions from those names;
      4. filter the candidates into groups keyed by the world diff they cause.

    Returns whatever ``env._filter_candidate_actions`` returns. Callers here
    treat it as a container keyed by world diff (presumably diff -> list of
    actions; confirm against the jericho version in use).
    """
    objects = env._identify_interactive_objects(use_object_tree=True)
    ranked_names = env._score_object_names(objects)
    return env._filter_candidate_actions(
        env.act_gen.generate_actions(ranked_names),
        use_ctypes=True,
        use_parallel=True,
    )
def analyze_step(idx, env, gold_act):
    """
    Execute the walkthrough action ``gold_act`` and check that the world diff
    it produces is among the diffs considered by get_valid_actions.

    Examine actions (``x ...`` / ``examine ...``) and ``z`` (wait) are still
    executed, so ``env`` stays in step with the walkthrough, but they are not
    checked.

    Args:
        idx: Walkthrough step index; used only to label printed messages.
        env: The jericho environment. It is advanced by exactly one step.
        gold_act: The walkthrough action to execute.

    Returns:
        The observation returned by ``env.step(gold_act)``.
    """
    # Skipped actions: execute them, but do not run the costly candidate-action
    # search, whose result would be discarded anyway. This assumes
    # get_detected_diffs leaves env's state as it found it (jericho restores
    # state after probing candidates) -- TODO confirm for the version in use.
    if gold_act.startswith(('x ', 'examine ')) or gold_act == 'z':
        obs, _, _, _ = env.step(gold_act)
        return obs

    # Diffs reachable from the pre-step state, computed before stepping.
    diff2acts = get_detected_diffs(env)
    obs, _, _, _ = env.step(gold_act)
    gold_diff = env._get_world_diff()
    if not env._world_changed():
        print(colored('{}. NoWorldChange gold_act: {}, obs: {}'.format(idx, gold_act, clean(obs)), 'magenta'))
    if gold_diff not in diff2acts:
        print(colored('{}. gold_act: {}-{} not in valids. Obs: {}'.format(idx, gold_act, gold_diff, clean(obs)), 'red'))
    return obs
def run_walkthrough(rom=None):
    """
    Replay the game's walkthrough and check two things at every step:

    1) Is the world diff produced by the walkthrough action among the diffs
       in get_valid_actions()? (This check is delegated to analyze_step.)
    2) Has the world-state hash of the current state been seen at an earlier
       step, or is it unique?

    Args:
        rom: Path to the game ROM. When None, the module-level ``ROM`` is
            used, as before, so existing callers keep working.
    """
    if rom is None:
        rom = ROM
    env = jericho.FrotzEnv(rom)
    try:
        obs, _ = env.reset()
        walkthrough = env.get_walkthrough()
        # world-state hash -> walkthrough step indices where it was observed
        hashes = {}
        for idx, gold_act in enumerate(walkthrough):
            whash = env.get_world_state_hash()
            seen_at = hashes.setdefault(whash, [])
            if seen_at:
                # seen_at holds only the earlier indices here; idx is appended below.
                print(colored('{}: {} States with same hash: {}'.format(idx, obs, seen_at), 'cyan'))
            seen_at.append(idx)
            print(f'Step {idx} - {gold_act}')
            step_obs = analyze_step(idx, env, gold_act)
            # If analyze_step returns an observation, treat it as the text after
            # gold_act, so the message above describes the current state rather
            # than the initial reset() text.
            if step_obs is not None:
                obs = step_obs
    finally:
        env.close()
def parse_args(argv=None):
    """
    Parse command-line arguments.

    Args:
        argv: Argument list to parse. None (the default) makes argparse read
            ``sys.argv[1:]``, matching the original behavior. Pass an explicit
            list to call this programmatically or from tests.

    Returns:
        argparse.Namespace with a single ``rom`` attribute (path to the game file).
    """
    parser = argparse.ArgumentParser(
        description='Replay a game walkthrough and check jericho valid-action detection.')
    parser.add_argument('rom', type=str, help='Rom File')
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()
    # Assigning at module scope already binds a global name; run_walkthrough()
    # reads this module-level ROM by default.
    ROM = args.rom
    run_walkthrough()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment