# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @g-leech
# Last active April 17, 2018 10:19
# Show Gist options
#   • Star 0 You must be signed in to star a gist
#   • Fork 0 You must be signed in to fork a gist
#   • Save g-leech/d4caf8e0fa8f187114bac91a31e69f92 to your computer and use it in GitHub Desktop.
# Save g-leech/d4caf8e0fa8f187114bac91a31e69f92 to your computer and use it in GitHub Desktop.
from ai_safety_gridworlds.environments.shared.safety_game import Actions
from hashlib import sha1
import numpy
import time
from IPython import display
import copy
# Environment under study, built at level 0.
# NOTE(review): `sokoban_game` is not imported in this file -- presumably it
# is bound in another notebook cell; confirm before running as a script.
env = sokoban_game(level=0)
# Actions the crawler tries from each state: every `Actions` member except QUIT.
ACTIONS = list(filter(lambda candidate: candidate is not Actions.QUIT, Actions))
def get_frame(step, x=0, y=-1):
    """Return the RGB observation of `step` with axis `x` moved to position `y`.

    The file imports the module as `numpy` (there is no `np` alias), so the
    module name is used directly.

    Args:
      step: time step whose `observation['RGB']` array is read.
      x: source axis passed to `numpy.moveaxis`.
      y: destination axis passed to `numpy.moveaxis`. The defaults (0, -1)
        match the call sites in this file and move the leading axis to the
        end -- presumably channel-first to channel-last for imshow.

    Returns:
      A view of the observation with the axis moved.
    """
    color_state = step.observation['RGB']
    return numpy.moveaxis(color_state, x, y)
# Set up the interactive figure that refresh_screen() redraws through `im`.
# NOTE(review): `plt` is not imported in this file -- presumably
# `matplotlib.pyplot` is imported as `plt` in another notebook cell; confirm.
plt.ion()
plt.axis('off')
# NOTE(review): this reset only provides a first frame to draw; the env is
# reset again later (initialState) before the crawl starts.
time_step = env.reset()
# get_frame(..., 0, -1) moves the leading axis to the end -- presumably
# channel-first RGB to the channel-last layout imshow expects.
im = plt.imshow(get_frame(time_step, 0, -1), animated=True)
def refresh_screen(step, x=0, y=-1, delay=0.6):
    """Redraw the module-level image `im` with the frame from `step`.

    Args:
      step: time step whose `observation['RGB']` is drawn.
      x: source axis forwarded to get_frame.
      y: destination axis forwarded to get_frame.
      delay: seconds to sleep before redrawing, so a sequence of calls is
        watchable as an animation. Defaults to the previously hard-coded 0.6.
    """
    time.sleep(delay)
    im.set_data(get_frame(step, x, y))
    # wait=True defers clearing until the new figure is displayed, so the
    # notebook output is replaced instead of flickering or accumulating.
    display.clear_output(wait=True)
    display.display(plt.gcf())
def merge_two_dicts(x, y):
    """Return a new dict holding x's entries, overridden by y's.

    Neither argument is modified. When a key appears in both, the value from
    y wins.
    """
    merged = x.copy()  # shallow copy, so the caller's x is left untouched
    merged.update(y)
    return merged
def hash_board(state):
    """Return a hex SHA-1 digest of the RGB observation in `state`.

    Two time steps get the same digest exactly when their RGB arrays have the
    same bytes in C order. The crawl in this file uses the digest as a state
    id.

    hashlib reads the array through the buffer protocol, which requires a
    C-contiguous buffer, so a non-contiguous view (e.g. a transposed array)
    would raise. numpy.ascontiguousarray returns already-contiguous input as
    is (no copy), so digests of contiguous observations are unchanged.
    """
    return sha1(numpy.ascontiguousarray(state.observation['RGB'])).hexdigest()
# Root of the crawl: the time step returned by a fresh reset of `env`.
initialState = env.reset()


def crawl_for_transitions(envir, lastState, hashMap):
    """Enumerate the boards reachable from `lastState` and the moves between them.

    `envir` must currently be in `lastState`. It is never stepped itself.
    Each action is tried on a `copy.deepcopy` of the environment that
    reached the board being expanded, so sibling actions all start from the
    same board.

    Args:
      envir: environment positioned at `lastState`.
      lastState: time step of `envir`. Its board is the root of the crawl.
      hashMap: dict from board hash (hash_board) to RGB observation. It is
        updated in place. Boards already present are not expanded again,
        except the root, which is always expanded.

    Returns:
      A `(transitions, hashMap)` pair:
      * transitions, a dict from (stateHash, action, nextStateHash) to 1 if
        the action changes the board and 0 if it leaves it unchanged. Every
        action is recorded for every expanded board, including moves back
        into boards that were already discovered.
      * hashMap, the same dict that was passed in. It now also holds the
        root and every board discovered by the crawl.

    NOTE(review): boards reached after the episode has ended are expanded
    like any other. Confirm how this env handles step() after termination.
    """
    transitions = {}
    rootIndex = hash_board(lastState)
    if rootIndex not in hashMap:
        hashMap[rootIndex] = lastState.observation['RGB']

    # Explicit DFS stack of (environment, hash of the board it is in).
    # Iterating instead of recursing avoids Python's recursion limit on long
    # action sequences. Pairing each env with its own board hash means every
    # transition is attributed to the board the action was actually taken
    # from.
    frontier = [(envir, rootIndex)]
    while frontier:
        parentEnv, lastIndex = frontier.pop()
        for action in ACTIONS:
            frozenEnv = copy.deepcopy(parentEnv)
            nextState = frozenEnv.step(action)
            index = hash_board(nextState)
            changed = index != lastIndex
            transitions[(lastIndex, action, index)] = 1 if changed else 0
            if index not in hashMap:
                hashMap[index] = nextState.observation['RGB']
                # refresh_screen(nextState)  # uncomment to watch the crawl (slow)
                frontier.append((frozenEnv, index))
    return transitions, hashMap
# Crawl every board reachable from initialState and report the results.
hashMap = {}
transitions, hashMap = crawl_for_transitions(env, initialState, hashMap)
print(transitions)
# A bare `len(hashMap)` is only shown as a notebook cell result. Print it so
# the number of distinct boards is also reported when run as a script.
print(len(hashMap))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment