Skip to content

Instantly share code, notes, and snippets.

@AndyGrant
Created October 2, 2023 11:43
Show Gist options
  • Save AndyGrant/9a0aaf58c9bc7fc6ecde339f23cf9938 to your computer and use it in GitHub Desktop.
Save AndyGrant/9a0aaf58c9bc7fc6ecde339f23cf9938 to your computer and use it in GitHub Desktop.
#!/bin/python3
import argparse
import chess
import chess.pgn
import chess.syzygy
import io
import multiprocessing
import numpy as np
import sys
import traceback
np.set_printoptions(suppress=True)


def _parse_bool(value):
    """Convert a command-line string such as 'true'/'false'/'1'/'0' to a bool.

    argparse's ``type=bool`` would call ``bool(value)``, which is True for
    every non-empty string (including 'False'), so a boolean option could
    never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


p = argparse.ArgumentParser()
p.add_argument('pgn_path', help='Path to pgn file')
p.add_argument('-we', '--win-eval', help='In terms of centipawns', default=600, type=int)
p.add_argument('-ws', '--win-streak', help='In terms of full moves (2 ply)', default=5, type=int)
p.add_argument('-de', '--draw-eval', help='In terms of centipawns', default=15, type=int)
p.add_argument('-ds', '--draw-streak', help='In terms of full moves (2 ply)', default=6, type=int)
p.add_argument('-dm', '--draw-movecount', help='In terms of full moves (2 ply)', default=32, type=int)
p.add_argument('-dr', '--draw-reset-on-zeroing', help='Reset draw streak on zeroing move', default=True, type=_parse_bool)
p.add_argument('--syzygy-path', help='Location of Syzygy WDL files', default='', type=str)
p.add_argument('--syzygy-limit', help='Piece-count limit for tablebase probing (0 disables probing)', default=0, type=int)
args = p.parse_args()

# Only opened when a path is given; probing code checks args.syzygy_path first.
TABLEBASES = None
if args.syzygy_path:
    TABLEBASES = chess.syzygy.open_tablebase(args.syzygy_path)

# Game outcomes from White's point of view; also used as array indices.
WIN = 0
DRAW = 1
LOSS = 2

# Syzygy WDL values, relative to the side to move.
TB_WIN = 2
TB_DRAW = 0
TB_LOSS = -2

# Mate scores are encoded in centipawns as +/-(MATE_SCORE - mate distance).
MATE_SCORE = 32000
MATE_BOUND = 31000
class DataCollector(chess.pgn.BaseVisitor):
    """PGN visitor that records per-ply data for retroactive adjudication.

    For every visited move it appends one entry to each of four parallel lists:
      * scores     -- engine score in centipawns parsed from that move's
                      comment, or None when the move has no parsable score;
      * zeroing    -- whether the move resets the halfmove clock;
      * full_moves -- ``board.fullmove_number`` when the move is visited;
      * fens       -- ``board.fen()`` when the move is visited, recorded only
                      for zeroing moves (None otherwise).

    Because a score slot is opened for every move, the lists always have equal
    length, so a move without a comment cannot shift later scores onto the
    wrong ply.

    NOTE(review): every visited move is recorded, so PGNs that contain side
    variations would mix variation moves into the data. Mainline-only PGNs
    are assumed.
    """

    def begin_game(self):
        self.scores = []
        self.zeroing = []
        self.full_moves = []
        self.fens = []

    def visit_comment(self, comment):
        # Attach the score to the most recently visited move. A comment seen
        # before any move (e.g. a game-level comment) belongs to no ply.
        if self.scores:
            self.scores[-1] = self.parse_comment(comment)

    def visit_move(self, board, move):
        zeroing = board.is_zeroing(move)
        self.zeroing.append(zeroing)
        self.full_moves.append(board.fullmove_number)
        self.fens.append(board.fen() if zeroing else None)
        # Placeholder until a comment for this move is visited (if ever).
        self.scores.append(None)

    def result(self):
        """Return an iterator of (score, zeroing, full_move, fen) per ply."""
        return zip(self.scores, self.zeroing, self.full_moves, self.fens)

    def parse_comment(self, comment):
        """Parse an engine comment into a centipawn score.

        Comments look like ``{ +0.12/10 ... }`` or ``{ -M7/20 ... }``. Mate
        scores map to +/-(MATE_SCORE - distance).

        Returns:
            The score as an int, or None if the comment carries no score.
        """
        if '/' not in comment:
            return None
        text = comment.split('/')[0].strip()
        try:
            if text.startswith('-M'):
                return -MATE_SCORE + int(text[2:])
            if text.startswith('+M'):
                return +MATE_SCORE - int(text[2:])
            # round(), not int(): e.g. 0.29 * 100 == 28.999999999999996.
            return round(float(text) * 100)
        except ValueError:
            # Free-form text that happens to contain '/'; treat as no score.
            return None
def get_next_game_str(file_name):
    """Yield the text of each game in a PGN file, one game at a time.

    A PGN game is a tag-pair section followed by a movetext section, each
    terminated by a blank line, so a game ends at its second blank-line
    separator. A run of consecutive blank lines counts as one separator, and
    a final game that is not followed by a trailing blank line is still
    yielded.

    Args:
        file_name: path of the PGN file to read.

    Yields:
        str: the raw text of one game, including its terminating blank line
        when present.
    """
    with open(file_name) as pgn_file:
        chunk = []
        separators = 0
        previous_blank = True  # leading blank lines are not separators
        for line in pgn_file:
            chunk.append(line)
            is_blank = line.strip() == ''
            if is_blank and not previous_blank:
                separators += 1
                if separators == 2:
                    yield ''.join(chunk)
                    chunk = []
                    separators = 0
            previous_blank = is_blank
        # Emit a trailing game that lacked a terminating blank line.
        if any(pending.strip() for pending in chunk):
            yield ''.join(chunk)
def retroactively_adjudicate(game):
    """Replay a finished game and find where adjudication would have stopped it.

    Walks the per-ply data gathered by DataCollector. On zeroing moves, when
    ``args.syzygy_path`` is set and the position has at most
    ``args.syzygy_limit`` pieces, the Syzygy WDL tables are consulted. After
    that, the reported engine score is checked against the win/draw/loss
    streak thresholds from ``args``.

    Args:
        game: a parsed ``chess.pgn.Game``.

    Returns:
        ``(plies_played, result)``. ``result`` is WIN, DRAW or LOSS from
        White's point of view when an adjudication rule fired, else None.
        ``plies_played`` is the number of plies played when the decision was
        made. For a tablebase decision, that is the plies before the probed
        position. For a score decision, it includes the ply whose score
        triggered the decision. When no rule fired, it is every ply visited
        (0 for a game without moves).
    """
    win_streak, draw_streak, loss_streak = 0, 0, 0
    whites_first_ply = 0 if game.turn() == chess.WHITE else 1
    plies_seen = 0

    for ply, (score, zeroing, full_move, fen) in enumerate(game.accept(DataCollector())):
        plies_seen = ply + 1
        # Scores and WDL values are relative to the side making this ply.
        black_to_move = ply % 2 != whites_first_ply

        # NOTE(review): fen is the position *before* the zeroing move (it is
        # recorded in DataCollector.visit_move), so its halfmove clock is not
        # necessarily 0. Confirm this is the intended probe point.
        if zeroing and args.syzygy_path:
            board = chess.Board(fen)
            if len(board.piece_map()) <= args.syzygy_limit:
                # get_wdl returns None instead of raising when no table covers
                # the position; fall back to the score-based rules then.
                wdl = TABLEBASES.get_wdl(board)
                if wdl is not None:
                    # Convert to White's POV
                    if black_to_move:
                        wdl = -wdl
                    # Cursed wins / blessed losses (+-1) fall through.
                    if wdl == TB_WIN:
                        return ply, WIN
                    if wdl == TB_DRAW:
                        return ply, DRAW
                    if wdl == TB_LOSS:
                        return ply, LOSS

        # It is possible for an engine to report nothing, and if so we cannot do anything
        if score is None:
            continue

        # Convert Black's scores to a White POV
        if black_to_move:
            score = -score

        # If we don't care about zeroing, then pretend it is False
        if not args.draw_reset_on_zeroing:
            zeroing = False

        # Determine if any of the Adjudication streak conditions were met during this ply
        suggests_win = score >= args.win_eval
        suggests_draw = abs(score) < args.draw_eval and not zeroing and full_move >= args.draw_movecount
        suggests_loss = score <= -args.win_eval

        # Update streak counters based on what this move suggested
        win_streak = win_streak + 1 if suggests_win else 0
        draw_streak = draw_streak + 1 if suggests_draw else 0
        loss_streak = loss_streak + 1 if suggests_loss else 0

        # Streak options are in full moves, i.e. two plies each. The move at
        # this ply has been played, so ply + 1 plies are done.
        if win_streak >= 2 * args.win_streak:
            return ply + 1, WIN
        if draw_streak >= 2 * args.draw_streak:
            return ply + 1, DRAW
        if loss_streak >= 2 * args.win_streak:
            return ply + 1, LOSS

    return plies_seen, None  # No Adjudication was done
def process_game(game):
    """Adjudicate one PGN game and pair the prediction with the real outcome.

    Args:
        game: the text of a single PGN game.

    Returns:
        A tuple ``(actual_result, predicted_result, total_ply, played_ply)``.
        ``actual_result`` is the game's real result index (WIN/DRAW/LOSS).
        ``predicted_result`` is the adjudicated result index, or None.
        ``total_ply`` is the game length in plies. ``played_ply`` is the
        value reported by retroactively_adjudicate.

    Raises:
        ValueError: if the text contains no PGN game.
        KeyError: if the Result tag is not 1-0, 1/2-1/2 or 0-1.
    """
    result_to_idx = {'1-0': WIN, '1/2-1/2': DRAW, '0-1': LOSS}
    parsed = chess.pgn.read_game(io.StringIO(game))
    if parsed is None:
        raise ValueError('no PGN game found in input')

    # Prefer the PlyCount tag, but not every PGN writer emits it.
    ply_count = parsed.headers.get('PlyCount')
    if ply_count is not None:
        total_ply = int(ply_count)
    else:
        total_ply = sum(1 for _ in parsed.mainline_moves())

    actual_result = result_to_idx[parsed.headers['Result']]
    played_ply, predicted_result = retroactively_adjudicate(parsed)
    return (actual_result, predicted_result, total_ply, played_ply)
if __name__ == '__main__':
    # Split the PGN into games, adjudicate each one in a worker process, then
    # report how much play adjudication would have skipped and how its
    # predictions compare with the real results.
    print(f'Loading individual games from {args.pgn_path}... ', end='')

    # Split the PGN into individual games as strings
    games = list(get_next_game_str(args.pgn_path))
    num_games = len(games)
    print(f'Found {num_games} games\n')

    # Every statistic below is normalised by num_games.
    if num_games == 0:
        sys.exit('No games to analyse.')

    # Spawn more processes to analyze each game; the context manager shuts
    # the pool down once the results have been collected.
    with multiprocessing.Pool() as pool:
        data = pool.map(process_game, games)

    # Indexed by actual result (WIN/DRAW/LOSS): plies in the games, plies
    # played up to adjudication, and number of games.
    plys_total = np.zeros(3, dtype=np.int64)
    plys_viewed = np.zeros(3, dtype=np.int64)
    games_viewed = np.zeros(3, dtype=np.int64)
    # results[predicted][actual] counts adjudicated games only.
    results = np.zeros((3, 3), dtype=np.int64)

    for actual_result, predicted_result, total_ply, played_ply in data:
        plys_total[actual_result] += total_ply
        games_viewed[actual_result] += 1
        plys_viewed[actual_result] += played_ply
        if predicted_result is not None:
            results[predicted_result, actual_result] += 1

    # Fraction of plies skipped per actual result; a category with no plies
    # reports 0 saved instead of dividing by zero.
    viewed_ratio = np.divide(plys_viewed, plys_total, out=np.ones(3), where=plys_total != 0)
    ply_savings_per = 1 - viewed_ratio
    ply_savings_sum = np.sum(ply_savings_per * (games_viewed / num_games))

    print('Percentage of actual result distribution : ', games_viewed / num_games)
    print('Percentage of ply saved via adjudication : ', ply_savings_per)
    print('Percentage of ply pruned in all total : ', ply_savings_sum)

    print('\n\nTable of results, [Predicted][Actual]\n')
    # Normalise each predicted-result row to fractions; all-zero rows stay 0.
    results = results.astype(np.float64)
    row_sums = results.sum(axis=1, keepdims=True)
    results /= np.where(row_sums != 0, row_sums, 1)
    print(results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment