Created
October 2, 2023 11:43
-
-
Save AndyGrant/9a0aaf58c9bc7fc6ecde339f23cf9938 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/python3 | |
import argparse | |
import chess | |
import chess.pgn | |
import chess.syzygy | |
import io | |
import multiprocessing | |
import numpy as np | |
import sys | |
import traceback | |
np.set_printoptions(suppress=True)


def _str_to_bool(value):
    """Convert a command-line string such as 'true' or 'False' to a bool.

    argparse's ``type=bool`` cannot be used for this: bool('False') is True,
    so any non-empty value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: for unrecognised spellings.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 'yes', 'y', 'on'):
        return True
    # '' stays False for compatibility with the former bool('') behaviour
    if lowered in ('', '0', 'false', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


p = argparse.ArgumentParser()
p.add_argument('pgn_path', help='Path to pgn file')
p.add_argument('-we', '--win-eval', help='In terms of centipawns', default=600, type=int)
p.add_argument('-ws', '--win-streak', help='In terms of full moves (2 ply)', default=5, type=int)
p.add_argument('-de', '--draw-eval', help='In terms of centipawns', default=15, type=int)
p.add_argument('-ds', '--draw-streak', help='In terms of full moves (2 ply)', default=6, type=int)
p.add_argument('-dm', '--draw-movecount', help='In terms of full moves (2 ply)', default=32, type=int)
p.add_argument('-dr', '--draw-reset-on-zeroing', help='Reset draw streak on zeroing move (true/false)', default=True, type=_str_to_bool)
p.add_argument('--syzygy-path', help='Location of Syzygy WDL files', default='', type=str)
p.add_argument('--syzygy-limit', help='Piece-count limit for tablebase probing (0 disables)', default=0, type=int)
args = p.parse_args()

# Opened only when a path is given; the adjudication code probes it solely
# under `if ... args.syzygy_path`. Bound to None otherwise so the name exists.
TABLEBASES = chess.syzygy.open_tablebase(args.syzygy_path) if args.syzygy_path else None
# Game outcomes from White's point of view; also used as indices into the
# per-result statistics arrays.
WIN, DRAW, LOSS = 0, 1, 2

# Syzygy WDL values, relative to the side to move.
TB_WIN, TB_DRAW, TB_LOSS = 2, 0, -2

# Mate announcements "+M<n>" / "-M<n>" are encoded as +/-(MATE_SCORE - n).
MATE_SCORE, MATE_BOUND = 32000, 31000
class DataCollector(chess.pgn.BaseVisitor):
    """Collect per-move engine scores and move metadata from a PGN game.

    ``result()`` yields one ``(score, zeroing, full_move, fen)`` tuple per
    visited move, in visiting order:

    * ``score``: centipawn score parsed from the comment attached to that
      move, or None when the move has no parsable score comment.
    * ``zeroing``: ``board.is_zeroing(move)`` for the move.
    * ``full_move``: ``board.fullmove_number`` of the board passed to
      ``visit_move``.
    * ``fen``: ``board.fen()`` of that board, stored only for zeroing moves
      (None otherwise).
    """

    def begin_game(self):
        self.scores = []
        self.zeroing = []
        self.full_moves = []
        self.fens = []

    def visit_comment(self, comment):
        # Attach the score to the most recent move so that scores stay aligned
        # with moves even when some moves carry no comment. A comment seen
        # before any move (e.g. a game-level comment) has no move to belong to.
        if self.scores:
            self.scores[-1] = self.parse_comment(comment)

    def visit_move(self, board, move):
        zeroing = board.is_zeroing(move)
        self.scores.append(None)  # placeholder, filled in by visit_comment
        self.zeroing.append(zeroing)
        self.full_moves.append(board.fullmove_number)
        self.fens.append(board.fen() if zeroing else None)

    def result(self):
        return zip(self.scores, self.zeroing, self.full_moves, self.fens)

    def parse_comment(self, comment):
        """Return the centipawn score in a ``{ +0.12/10 ... }`` style comment.

        ``+M<n>`` maps to ``MATE_SCORE - n`` and ``-M<n>`` to
        ``-MATE_SCORE + n``. Returns None when the comment holds no score.
        """
        if '/' not in comment:
            return None
        text = comment.strip().split('/')[0]
        try:
            if '-M' in text:
                return -MATE_SCORE + int(text[2:])
            if '+M' in text:
                return +MATE_SCORE - int(text[2:])
            # round(), not int(): float('0.29') * 100 == 28.999999999999996
            return round(float(text) * 100)
        except ValueError:
            # A '/' appeared in a comment that is not an engine score
            return None
def get_next_game_str(file_name):
    """Yield each game of a PGN file as a separate string.

    A game is taken to be a header section and a movetext section, each
    terminated by a blank line. Blank lines seen before a game starts are
    skipped, so stray or repeated separators do not desynchronise the split.
    A final game that is not followed by a blank line (the file simply ends)
    is still yielded.

    Args:
        file_name: path of the PGN file to read.

    Yields:
        str: the raw text of one game, including its terminating blank lines.
    """
    lines = []
    blank_lines = 0
    with open(file_name) as pgn_file:
        for line in pgn_file:
            is_blank = line.strip() == ''
            if is_blank and not lines:
                continue  # separator noise between games
            lines.append(line)
            blank_lines += is_blank
            if blank_lines == 2:
                yield ''.join(lines)
                lines, blank_lines = [], 0
    # The last game may lack the trailing blank line
    if any(line.strip() for line in lines):
        yield ''.join(lines)
def retroactively_adjudicate(game):
    """Replay a finished game and report where adjudication would have fired.

    Walks the per-move data produced by DataCollector and applies, in order:

    1. Tablebase adjudication (only when --syzygy-path is set): for a zeroing
       move whose recorded position has at most --syzygy-limit pieces, a WDL
       of TB_WIN / TB_DRAW / TB_LOSS (after conversion to White's point of
       view) ends the game as WIN / DRAW / LOSS.
    2. Score adjudication on the White-POV engine score: WIN (LOSS) once the
       score is >= win_eval (<= -win_eval) on 2 * win_streak consecutive
       scored plies; DRAW once |score| < draw_eval on 2 * draw_streak
       consecutive scored plies at or after full move draw_movecount. A
       zeroing move breaks the draw streak unless --draw-reset-on-zeroing is
       false.

    Plies without a parsable score are skipped: they neither extend nor
    reset any streak.

    Args:
        game: a parsed chess.pgn.Game.

    Returns:
        (plies_played, result): the number of plies played up to and
        including the ply at which adjudication fired (or every visited ply
        when it never fired; 0 for a game without moves), and WIN / DRAW /
        LOSS from White's point of view, or None when nothing fired.
    """
    win_streak, draw_streak, loss_streak = 0, 0, 0
    whites_first_ply = 0 if game.turn() == chess.WHITE else 1
    plies_played = 0  # defined even when the game has no moves

    for ply, (score, zeroing, full_move, fen) in enumerate(game.accept(DataCollector())):
        plies_played = ply + 1  # ply is a 0-based index; report a count
        # Plies of the opposite parity to White's first move are Black's
        blacks_ply = ply % 2 != whites_first_ply

        if zeroing and args.syzygy_path:
            # NOTE(review): `fen` is taken from the board handed to
            # DataCollector.visit_move, which python-chess documents as the
            # position *before* the move. WDL probes are documented as assuming
            # the position was reached right after a zeroing move, so the
            # 50-move-rule part of the answer may be off. Confirm this is intended.
            board = chess.Board(fen)
            if len(board.piece_map()) <= args.syzygy_limit:
                # get_wdl() returns None instead of raising KeyError (as
                # probe_wdl() does) when no table covers the position
                wdl = TABLEBASES.get_wdl(board)
                if wdl is not None:
                    # WDL is relative to the side to move; convert to White's POV
                    if blacks_ply:
                        wdl = -wdl
                    if wdl == TB_WIN:
                        return plies_played, WIN
                    if wdl == TB_DRAW:
                        return plies_played, DRAW
                    if wdl == TB_LOSS:
                        return plies_played, LOSS
                    # Any other WDL value falls through to score adjudication

        # It is possible for an engine to report nothing, and if so we cannot do anything
        if score is None:
            continue

        # Convert Black's scores to a White POV
        if blacks_ply:
            score = -score

        # If we don't care about zeroing, then pretend it is False
        if not args.draw_reset_on_zeroing:
            zeroing = False

        # Determine if any of the adjudication streak conditions were met during this ply
        suggests_win = score >= args.win_eval
        suggests_draw = (abs(score) < args.draw_eval
                         and not zeroing
                         and full_move >= args.draw_movecount)
        suggests_loss = score <= -args.win_eval

        # Update streak counters based on what this move suggested
        win_streak = win_streak + 1 if suggests_win else 0
        draw_streak = draw_streak + 1 if suggests_draw else 0
        loss_streak = loss_streak + 1 if suggests_loss else 0

        # Streak options are in full moves; the counters count plies
        if win_streak >= 2 * args.win_streak:
            return plies_played, WIN
        if draw_streak >= 2 * args.draw_streak:
            return plies_played, DRAW
        if loss_streak >= 2 * args.win_streak:
            return plies_played, LOSS

    return plies_played, None  # No adjudication was done
def process_game(game):
    """Parse one PGN game string and retroactively adjudicate it.

    Args:
        game: text of a single PGN game, as yielded by get_next_game_str().

    Returns:
        (actual_result, predicted_result, total_ply, played_ply): the real
        result and the adjudicated result (WIN / DRAW / LOSS; the latter is
        None when no adjudication fired), the game's ply count (the PlyCount
        tag, or the number of mainline moves when the tag is missing), and
        the plies value reported by retroactively_adjudicate().

    Raises:
        ValueError: if `game` contains no parsable PGN game.
        KeyError: if the Result tag is not '1-0', '1/2-1/2' or '0-1'.
    """
    result_to_idx = {'1-0': WIN, '1/2-1/2': DRAW, '0-1': LOSS}

    parsed = chess.pgn.read_game(io.StringIO(game))
    if parsed is None:
        raise ValueError('no PGN game found in input')

    # Not every PGN writer emits a PlyCount tag; count the mainline instead
    ply_count = parsed.headers.get('PlyCount')
    if ply_count is not None:
        total_ply = int(ply_count)
    else:
        total_ply = sum(1 for _ in parsed.mainline_moves())

    actual_result = result_to_idx[parsed.headers['Result']]
    played_ply, predicted_result = retroactively_adjudicate(parsed)
    return (actual_result, predicted_result, total_ply, played_ply)
if __name__ == '__main__':
    print('Loading individual games from %s... ' % args.pgn_path, end='')

    # Split the PGN into individual games as strings
    games = list(get_next_game_str(args.pgn_path))
    num_games = len(games)
    print('Found %s games\n' % num_games)

    # Every statistic below is normalised by the number of games
    if num_games == 0:
        sys.exit('No games to analyze')

    # Analyze games in parallel; the context manager shuts the pool down
    with multiprocessing.Pool() as pool:
        data = pool.map(process_game, games)

    # Indexed by actual result (WIN / DRAW / LOSS): plies of the full games,
    # plies reported as played until adjudication, and number of games
    plys_total = np.zeros(3, dtype=np.int64)
    plys_viewed = np.zeros(3, dtype=np.int64)
    games_viewed = np.zeros(3, dtype=np.int64)
    # results[predicted][actual] counts games where adjudication fired
    results = np.zeros((3, 3), dtype=np.int64)

    for actual_result, predicted_result, total_ply, played_ply in data:
        plys_total[actual_result] += total_ply
        games_viewed[actual_result] += 1
        plys_viewed[actual_result] += played_ply
        if predicted_result is not None:
            results[predicted_result][actual_result] += 1

    # Divide only where plies exist, rather than adding +1 to every
    # denominator (which biased the ratios). A result class with no plies
    # reports 1.0 and has zero weight in the weighted sum.
    viewed_fraction = np.divide(plys_viewed, plys_total,
                                out=np.zeros(3), where=plys_total > 0)
    ply_savings_per = 1 - viewed_fraction
    result_distribution = games_viewed / num_games
    ply_savings_sum = np.sum(ply_savings_per * result_distribution)

    print('Percentage of actual result distribution : ', result_distribution)
    print('Percentage of ply saved via adjudication : ', ply_savings_per)
    print('Percentage of ply pruned in all total : ', ply_savings_sum)

    print('\n\nTable of results, [Predicted][Actual]\n')
    # Normalise each predicted row to a distribution over actual results;
    # rows with no games stay all-zero
    results = results.astype(np.float64)
    row_sums = results.sum(axis=1, keepdims=True)
    results = np.divide(results, row_sums,
                        out=np.zeros_like(results), where=row_sums != 0)
    print(results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment