Last active
May 10, 2016 12:59
-
-
Save wrongu/35861c5967a0bd66f5e7b2f1e8483f55 to your computer and use it in GitHub Desktop.
Script to diagnose differences in hdf5 features between RocAlphaGo's game converter and Thouis' sgf_to_hdf5 script. Usage: `python compare_go_features file1.hdf5 file2.hdf5 [comma,separated,feature,names]`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import h5py | |
import pdb | |
import numpy as np | |
# Number of planes occupied by each named feature.
feature_sizes = {
    "board": 3,
    "ones": 1,
    "turns_since": 8,
    "liberties": 8,
    "capture_size": 8,
    "self_atari_size": 8,
    "liberties_after": 8,
    "ladder_capture": 1,
    "ladder_escape": 1,
    "sensibleness": 1,
    "zeros": 1,
    "legal": 1
}

# First plane index of each feature in RocAlphaGo's state tensor.
ag_offsets = {
    "board": 0,
    "ones": 3,
    "turns_since": 4,
    "liberties": 12,
    "capture_size": 20,
    "self_atari_size": 28,
    "liberties_after": 36,
    # "ladder_capture": 44,
    # "ladder_escape": 45,
    "sensibleness": 44,  # change to 46 when ladders added
    "zeros": 45,  # likewise, change to 47
}

# First plane index of each feature in PyFuego's (sgf_to_hdf5) tensor.
pyf_offsets = {
    "board": 0,
    "ones": 3,
    "turns_since": 4,
    "liberties": 12,
    "capture_size": 20,
    "self_atari_size": 28,
    "liberties_after": 36,
    "ladder_capture": 44,
    "ladder_escape": 45,
    "sensibleness": 46,
    "legal": 47,
}

# Features present in both encodings. Sorted so that features are compared and
# reported in the same order on every run (set iteration order is arbitrary).
shared_features = sorted(set(ag_offsets) & set(pyf_offsets))
def to_fuego_coord(x, y, sz=19):
    """Translate a RocAlphaGo (x, y) position into PyFuego's indexing.

    PyFuego's y axis is reversed relative to RocAlphaGo's, so the returned pair is
    ``(sz - 1 - y, x)`` where ``sz`` is the board edge length.
    """
    flipped_y = (sz - 1) - y
    return flipped_y, x
def compare_features(pyfuego_feats, alphago_feats, featureset=None, break_on_error=False):
    """Compare PyFuego and RocAlphaGo feature planes, move by move.

    For each move and each requested feature, every plane of the PyFuego tensor is
    rotated into RocAlphaGo's orientation and compared element-wise with the
    corresponding RocAlphaGo plane. On a mismatch both planes are printed. For
    multi-plane features, the full feature slice at the first differing
    intersection is printed from both tensors as well.

    Args:
        pyfuego_feats: PyFuego states, indexed [move, plane, x, y]. The board size
            is read from axis 2 (assumed square board — TODO confirm).
        alphago_feats: RocAlphaGo states, indexed [move, plane, x, y].
        featureset: feature names to compare; defaults to ``shared_features``.
            Names missing from ``feature_sizes`` or from either offset table are
            reported and skipped.
        break_on_error: if true, enter pdb after each mismatching plane. At the pdb
            prompt, ``disp_board()`` prints the RocAlphaGo board of the current move.
    """
    if featureset is None:
        featureset = shared_features
    boardsize = pyfuego_feats.shape[2]

    def disp_board():
        # Debug helper meant to be called by hand from the pdb prompt. Planes 0-2
        # of the RocAlphaGo state are one-hot self/opponent/empty; `move` is read
        # from the enclosing loop at call time.
        markers = 'OX.'
        print("SELF: {} OPPONENT: {} EMPTY: {}".format(*markers))
        for y in range(boardsize):
            row = []
            for x in range(boardsize):
                self_other_empty = alphago_feats[move, 0:3, x, y]
                row.append(markers[list(self_other_empty).index(1)])
            print(" ".join(row))

    for move in range(len(pyfuego_feats)):
        for feature in featureset:
            if feature not in feature_sizes:
                print("unrecognized feature: {}".format(feature))
                continue
            if feature not in ag_offsets:
                print("no {} in RocAlphaGo features".format(feature))
                continue
            if feature not in pyf_offsets:
                print("no {} in PyFuego features".format(feature))
                continue
            sz = feature_sizes[feature]
            ag_off = ag_offsets[feature]
            pyf_off = pyf_offsets[feature]

            for plane in range(sz):
                # y coordinate is reversed in pyfuego, hence rot270
                pyf_plane = np.rot90(pyfuego_feats[move, pyf_off + plane, ...], 3)
                ag_plane = alphago_feats[move, ag_off + plane, ...]
                if np.array_equal(ag_plane, pyf_plane):
                    continue

                print("mismatch [move {}] in plane {} of {}".format(move, plane, feature))
                print("PyFuego version:")
                print(pyf_plane.transpose())  # transpose because numpy prints row-wise
                print("RocAlphaGo version:")
                print(ag_plane.transpose())
                if sz > 1:
                    err_xs, err_ys = np.where(pyf_plane != ag_plane)
                    err_x, err_y = err_xs[0], err_ys[0]
                    # Undo the 270-degree rotation to index the raw PyFuego tensor.
                    # Use the actual board size, not the 19x19 default.
                    rot_x, rot_y = to_fuego_coord(err_x, err_y, boardsize)
                    print("PyFuego slice at ({},{}):\t{}".format(
                        err_x, err_y, pyfuego_feats[move, pyf_off:pyf_off + sz, rot_x, rot_y]))
                    print("RocAlphaGo slice at ({},{}):\t{}".format(
                        err_x, err_y, alphago_feats[move, ag_off:ag_off + sz, err_x, err_y]))
                if break_on_error:
                    pdb.set_trace()
def pyfuego_gen(f):
    """Yield ``(name, states, actions)`` for every game in a PyFuego (sgf_to_hdf5) file.

    Game i occupies rows ``gameoffsets[i]`` up to the next game's offset, or up to
    the end of ``X`` for the last game. States come from ``X`` and actions from
    ``y``. Games are produced in sorted order of (name, start, end).
    """
    game_names = f['gamefiles'][...]
    first_rows = f['gameoffsets'][...]
    stop_rows = list(first_rows)[1:] + [f['X'].shape[0]]
    ordered = sorted(zip(game_names, first_rows, stop_rows))
    for name, lo, hi in ordered:
        yield (name, f['X'][lo:hi, ...], f['y'][lo:hi, ...])
def ag_gen(f):
    """Yield ``(name, states, actions)`` for every game in a RocAlphaGo file, by name.

    ``file_offsets[name]`` holds a (start, length) pair that selects the game's
    rows in both ``states`` and ``actions``.
    """
    offsets = f['file_offsets']
    for name in sorted(list(offsets)):
        start, length = offsets[name][...]
        stop = start + length
        yield (name, f['states'][start:stop, ...], f['actions'][start:stop, ...])
if __name__ == "__main__":
    # Script entry point: pair up the games of two feature files and compare each
    # game's states and actions.
    try:
        from itertools import izip as lazy_zip  # Python 2: don't load every game up front
    except ImportError:
        lazy_zip = zip  # Python 3: zip is already lazy

    if len(sys.argv) < 3:
        print("usage: compare_go_features file1.hdf5 file2.hdf5 [comma,separated,feature,names]")
        sys.exit(1)

    kwargs = {'break_on_error': True}
    if len(sys.argv) > 3:
        kwargs['featureset'] = sys.argv[3].split(',')

    with h5py.File(sys.argv[1], 'r') as ag_file, h5py.File(sys.argv[2], 'r') as pyf_file:
        # swap if files input in other order (only the PyFuego file has an 'X' dataset)
        if 'X' in ag_file.keys():
            ag_file, pyf_file = pyf_file, ag_file

        n_pyf_games = len(pyf_file['gamefiles'])
        n_ag_games = len(ag_file['file_offsets'])
        if n_pyf_games != n_ag_games:
            # Games are paired positionally and pairing stops at the shorter file.
            print("warning: {} PyFuego games vs {} RocAlphaGo games; unpaired games are skipped".format(
                n_pyf_games, n_ag_games))

        games = lazy_zip(pyfuego_gen(pyf_file), ag_gen(ag_file))
        for (pyf_n, pyf_states, pyf_actions), (ag_n, ag_states, ag_actions) in games:
            print("{} <> {}".format(pyf_n, ag_n))
            # compare state: compare_features already walks every move of the game,
            # so it is called once per game rather than once per move.
            compare_features(pyf_states, ag_states, **kwargs)
            # compare action, using this game's board size instead of assuming 19x19
            boardsize = pyf_states.shape[2]
            for move in range(len(pyf_states)):
                ag_x, ag_y = ag_actions[move, ...]
                if to_fuego_coord(ag_x, ag_y, boardsize) != tuple(pyf_actions[move, ...]):
                    print("action mismatch move {}".format(move))
                    if kwargs['break_on_error']:
                        pdb.set_trace()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment