Skip to content

Instantly share code, notes, and snippets.

@wrongu
Last active May 10, 2016 12:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wrongu/35861c5967a0bd66f5e7b2f1e8483f55 to your computer and use it in GitHub Desktop.
Save wrongu/35861c5967a0bd66f5e7b2f1e8483f55 to your computer and use it in GitHub Desktop.
Script to diagnose differences in HDF5 features between RocAlphaGo's game converter and Thouis' sgf_to_hdf5 script. Usage: `python compare_go_features file1.hdf5 file2.hdf5 [comma,separated,feature,names]` (the two HDF5 files may be given in either order).
import sys
import h5py
import pdb
import numpy as np
# Number of consecutive feature planes occupied by each named feature.
feature_sizes = dict(
    board=3,
    ones=1,
    turns_since=8,
    liberties=8,
    capture_size=8,
    self_atari_size=8,
    liberties_after=8,
    ladder_capture=1,
    ladder_escape=1,
    sensibleness=1,
    zeros=1,
    legal=1,
)
# Index of the first plane of each feature in RocAlphaGo's state arrays.
# RocAlphaGo does not emit ladder_capture / ladder_escape yet. Once it does
# (at planes 44 and 45), sensibleness and zeros shift up to 46 and 47.
ag_offsets = dict(
    board=0,
    ones=3,
    turns_since=4,
    liberties=12,
    capture_size=20,
    self_atari_size=28,
    liberties_after=36,
    sensibleness=44,
    zeros=45,
)
# Index of the first plane of each feature in the PyFuego (sgf_to_hdf5)
# feature arrays read from the 'X' dataset.
pyf_offsets = dict(
    board=0,
    ones=3,
    turns_since=4,
    liberties=12,
    capture_size=20,
    self_atari_size=28,
    liberties_after=36,
    ladder_capture=44,
    ladder_escape=45,
    sensibleness=46,
    legal=47,
)
# Features present in both encodings; this is the default comparison set.
# Sorted so features are checked (and mismatches reported) in the same
# order on every run; plain set iteration order is arbitrary.
shared_features = sorted(set(ag_offsets) & set(pyf_offsets))
def to_fuego_coord(x, y, sz=19):
    """Convert a RocAlphaGo (x, y) board coordinate to PyFuego indexing.

    PyFuego stores the y axis reversed, so y is mirrored across the board
    and the result is returned as (mirrored_y, x).

    Args:
        x, y: RocAlphaGo coordinates.
        sz: board size (defaults to 19).
    """
    mirrored_y = (sz - 1) - y
    return mirrored_y, x
def compare_features(pyfuego_feats, alphago_feats, featureset=shared_features, break_on_error=False):
    """Compare one game's feature planes from PyFuego and RocAlphaGo and print mismatches.

    Both arrays are indexed as [move, plane, x, y] by the code below. The y
    coordinate is reversed in PyFuego, so each PyFuego plane is rotated by
    270 degrees (np.rot90(..., 3)) before it is compared with RocAlphaGo's.

    Args:
        pyfuego_feats: PyFuego feature array for one game.
        alphago_feats: RocAlphaGo feature array for the same game.
        featureset: feature names to compare. A name missing from
            feature_sizes, ag_offsets or pyf_offsets is reported once and
            skipped.
        break_on_error: if True, enter pdb after every mismatching plane.
            At the pdb prompt, disp_board() prints RocAlphaGo's board for
            the current move.
    """
    boardsize = pyfuego_feats.shape[2]

    # Validate and look up each feature once instead of once per move, so a
    # bad name produces a single warning.
    checks = []
    for feature in featureset:
        if feature not in feature_sizes:
            print("unrecognized feature: {}".format(feature))
        elif feature not in ag_offsets:
            print("no {} in RocAlphaGo features".format(feature))
        elif feature not in pyf_offsets:
            print("no {} in PyFuego features".format(feature))
        else:
            checks.append((feature, feature_sizes[feature],
                           ag_offsets[feature], pyf_offsets[feature]))

    def disp_board():
        # Debug helper for the pdb prompt. Prints RocAlphaGo's board for the
        # current `move` from planes 0-2 (self / opponent / empty), one line
        # per y.
        markers = 'OX.'
        print("SELF: {} OPPONENT: {} EMPTY: {}".format(*markers))
        for y in range(boardsize):
            row = []
            for x in range(boardsize):
                self_other_empty = list(alphago_feats[move, 0:3, x, y])
                row.append(markers[self_other_empty.index(1)])
            sys.stdout.write(' '.join(row) + '\n')

    n_moves = len(pyfuego_feats)
    if len(alphago_feats) != n_moves:
        # Compare only the overlap; indexing past the shorter game would raise.
        print("move count mismatch: PyFuego {} vs RocAlphaGo {}".format(
            n_moves, len(alphago_feats)))
        n_moves = min(n_moves, len(alphago_feats))

    for move in range(n_moves):
        for feature, sz, ag_off, pyf_off in checks:
            for plane in range(sz):
                # y coordinate is reversed in pyfuego, hence rot270
                pyf_plane = np.rot90(pyfuego_feats[move, pyf_off + plane, ...], 3)
                ag_plane = alphago_feats[move, ag_off + plane, ...]
                if np.array_equal(ag_plane, pyf_plane):
                    continue

                print("mismatch [move {}] in plane {} of {}".format(move, plane, feature))
                print("PyFuego version:")
                print(pyf_plane.transpose())  # transpose because numpy prints row-wise
                print("RocAlphaGo version:")
                print(ag_plane.transpose())
                if sz > 1:
                    # Show every plane of this feature at the first mismatching point.
                    err_xs, err_ys = np.where(pyf_plane != ag_plane)
                    err_x, err_y = err_xs[0], err_ys[0]
                    # Undo the rotation to index the raw PyFuego array, using
                    # this board's actual size rather than the default of 19.
                    rot_x, rot_y = to_fuego_coord(err_x, err_y, boardsize)
                    print("PyFuego slice at ({},{}):\t{}".format(
                        err_x, err_y, pyfuego_feats[move, pyf_off:pyf_off + sz, rot_x, rot_y]))
                    print("RocAlphaGo slice at ({},{}):\t{}".format(
                        err_x, err_y, alphago_feats[move, ag_off:ag_off + sz, err_x, err_y]))
                if break_on_error:
                    pdb.set_trace()
def pyfuego_gen(f):
    """Yield (game_name, X, y) for every game in a PyFuego HDF5 file, sorted by name.

    f['gameoffsets'][i] is the first row of game f['gamefiles'][i] in
    f['X'] and f['y']. A game ends where the next offset begins, and the
    last game runs to the end of X, so the offsets are assumed to be
    ascending.
    """
    game_names = f['gamefiles'][...]
    game_starts = list(f['gameoffsets'][...])
    game_stops = game_starts[1:] + [f['X'].shape[0]]
    records = sorted(zip(game_names, game_starts, game_stops))
    for name, first, stop in records:
        rows = slice(first, stop)
        yield (name, f['X'][rows, ...], f['y'][rows, ...])
def ag_gen(f):
    """Yield (game_name, states, actions) for every game in a RocAlphaGo HDF5 file.

    Games are yielded in sorted name order. Each entry of f['file_offsets']
    holds a (start, length) pair that locates the game's rows in
    f['states'] and f['actions'].
    """
    offset_table = f['file_offsets']
    for game in sorted(list(offset_table)):
        first, count = offset_table[game][...]
        rows = slice(first, first + count)
        yield game, f['states'][rows, ...], f['actions'][rows, ...]
def main(argv):
    """Compare every game in two HDF5 feature files, one from each converter.

    argv[1] and argv[2] are the two HDF5 files, in either order. The file
    that contains an 'X' dataset is treated as the PyFuego (sgf_to_hdf5)
    file. argv[3], if given, is a comma-separated list of feature names to
    compare. Games are paired by sorted name. Returns a process exit status.
    """
    if len(argv) < 3:
        sys.stderr.write(
            "usage: {} file1.hdf5 file2.hdf5 [comma,separated,feature,names]\n".format(argv[0]))
        return 2

    break_on_error = True
    kwargs = {'break_on_error': break_on_error}
    if len(argv) > 3:
        kwargs['featureset'] = argv[3].split(',')

    with h5py.File(argv[1], 'r') as ag_file, h5py.File(argv[2], 'r') as pyf_file:
        # swap if files input in other order
        if 'X' in ag_file:
            ag_file, pyf_file = pyf_file, ag_file

        # zip() stops at the shorter file, and games are paired purely by
        # sorted name, so a game missing from one side misaligns the rest.
        n_pyf = len(pyf_file['gamefiles'])
        n_ag = len(ag_file['file_offsets'])
        if n_pyf != n_ag:
            print("warning: {} PyFuego games vs {} RocAlphaGo games".format(n_pyf, n_ag))

        games = zip(pyfuego_gen(pyf_file), ag_gen(ag_file))
        for (pyf_n, pyf_states, pyf_actions), (ag_n, ag_states, ag_actions) in games:
            print("{} <> {}".format(pyf_n, ag_n))
            # Compare states. compare_features already walks every move of
            # the game, so it is called once per game, not once per move.
            compare_features(pyf_states, ag_states, **kwargs)
            # Compare actions, converting with this board's actual size.
            boardsize = pyf_states.shape[2]
            for move in range(min(len(pyf_actions), len(ag_actions))):
                expected = to_fuego_coord(*ag_actions[move, ...], sz=boardsize)
                if expected != tuple(pyf_actions[move, ...]):
                    print("action mismatch move {}".format(move))
                    if break_on_error:
                        pdb.set_trace()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment