# Saved from GitHub gist TadaoYamaoka/d5056c3470526e607e9845224ad97bb3.
# GitHub warned that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; review it in an editor
# that reveals hidden Unicode characters.
import argparse
import glob
import os
import re

import numpy as np
import pandas as pd
from cshogi import *
from cshogi import CSA
from scipy.optimize import curve_fit
from scipy.special import expit
# Command-line interface: where to find the CSA records and which games to use.
parser = argparse.ArgumentParser(
    description='Fit the scale "a" of win_rate = 1 / (1 + exp(-score / a)) '
                'to the evaluations recorded in CSA game files.')
parser.add_argument('csa_dir',
                    help='directory searched recursively for *.csa files')
parser.add_argument('--filter_moves', type=int, default=50,
                    help='skip games with fewer moves than this '
                         '(default: %(default)s)')
parser.add_argument('--filter_rating', type=int, default=3800,
                    help='keep only games in which at least one player is rated '
                         'at least this value; a value <= 0 disables the filter '
                         '(default: %(default)s)')
parser.add_argument('--filter_name',
                    help='regular expression; when given, only the evaluations of '
                         'moves played by players whose name matches it are used')
args = parser.parse_args()
filter_moves = args.filter_moves
filter_rating = args.filter_rating
# 評価値から勝率への変換 | |
def score_to_value(score, a): | |
return 1.0 / (1.0 + np.exp(-score / a)) | |
# Collect one (score, result) sample per recorded evaluation of every
# qualifying game.
scores = []
results = []
# Compile the optional player-name filter once, outside the file loop.
ptn = re.compile(args.filter_name) if args.filter_name else None
for filepath in glob.glob(os.path.join(args.csa_dir, '**', '*.csa'), recursive=True):
    for kif in CSA.Parser.parse_file(filepath):
        # Only games that ended by resignation, repetition or declaration and
        # that lasted at least filter_moves moves.
        if kif.endgame not in ('%TORYO', '%SENNICHITE', '%KACHI') or len(kif.moves) < filter_moves:
            continue
        # Keep the game when at least one player is rated >= filter_rating.
        # `or 0` makes a missing (None) rating count as unrated instead of
        # raising TypeError. NOTE(review): confirm whether cshogi can yield None.
        rating0 = kif.ratings[0] or 0
        rating1 = kif.ratings[1] or 0
        if filter_rating > 0 and rating0 < filter_rating and rating1 < filter_rating:
            continue
        if ptn is not None:
            # Use only the evaluations of moves played by matching players:
            # even-indexed moves belong to player 0, odd-indexed to player 1.
            kif_scores = []
            if ptn.search(kif.names[0] or ''):
                kif_scores.extend(kif.scores[0::2])
            if ptn.search(kif.names[1] or ''):
                kif_scores.extend(kif.scores[1::2])
            if not kif_scores:
                continue
        else:
            kif_scores = kif.scores
        scores.extend(kif_scores)
        # Every sample of a game is labelled with that game's outcome: 0.5 for
        # a draw, otherwise 2 - kif.win (presumably 1 when player 0 wins and 0
        # when player 1 wins, assuming cshogi's BLACK_WIN == 1 and
        # WHITE_WIN == 2 -- TODO confirm).
        if kif.win == DRAW:
            outcome = 0.5
        else:
            outcome = 2 - kif.win
        results.extend([outcome] * len(kif_scores))
def _fit_scale(frame):
    """Fit ``score_to_value`` to *frame*'s ``score``/``result`` columns and print ``a``.

    curve_fit raises on an empty data set (for example when no sample has
    ``|score| >= 1000``), so an empty *frame* prints a notice instead of
    aborting the remaining fits.
    """
    if frame.empty:
        print('no samples')
        return
    popt, _pcov = curve_fit(score_to_value, frame['score'], frame['result'], p0=[600.0])
    print(popt)


df = pd.DataFrame({'score': scores, 'result': results})
# Exclude positions without an evaluation and mate positions: score == 0
# (presumably how a missing evaluation is recorded; this also drops genuine
# 0 evaluations) and |score| >= 30000 (mate scores).
df = df[(df['score'].abs() < 30000) & (df['score'] != 0)]
print(df['score'].describe())
print(df['result'].describe())
_fit_scale(df)

print('score < 1000')
df1 = df[df['score'].abs() < 1000]
_fit_scale(df1)

print('score >= 1000')
df2 = df[df['score'].abs() >= 1000]
_fit_scale(df2)
# GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment."