Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import argparse
from cshogi import *
from cshogi import CSA
import numpy as np
import os
import glob
import re
import pandas as pd
from scipy.optimize import curve_fit
# Command-line interface. Each filter's help text describes how the collection
# loop below actually uses it.
parser = argparse.ArgumentParser(
    description='Fit the scale of a logistic curve mapping engine evaluation '
                'scores to game results, using CSA game records.')
parser.add_argument('csa_dir',
                    help='directory searched recursively for *.csa files')
parser.add_argument('--filter_moves', type=int, default=50,
                    help='skip games with fewer moves than this '
                         '(default: %(default)s)')
parser.add_argument('--filter_rating', type=int, default=3800,
                    help='skip games in which both players are rated below '
                         'this; 0 or less disables the filter '
                         '(default: %(default)s)')
parser.add_argument('--filter_name',
                    help='regular expression; when given, only the scores of '
                         'players whose name matches are used')
args = parser.parse_args()
filter_moves = args.filter_moves
filter_rating = args.filter_rating
# 評価値から勝率への変換
def score_to_value(score, a):
return 1.0 / (1.0 + np.exp(-score / a))
# Per-position data for the fit: evaluation scores and the matching game outcome.
scores = []
results = []

if args.filter_name:
    # Compiled once; matched against each player's name in the loop below.
    ptn = re.compile(args.filter_name)


def _game_is_usable(kif):
    """Return True when *kif* passes the ending, length and rating filters."""
    # Only games with one of these CSA endings are used.
    if kif.endgame not in ('%TORYO', '%SENNICHITE', '%KACHI'):
        return False
    if len(kif.moves) < filter_moves:
        return False
    # A game is dropped only when BOTH players are rated below the threshold;
    # a threshold of 0 or less disables the check.
    if filter_rating > 0 and kif.ratings[0] < filter_rating and kif.ratings[1] < filter_rating:
        return False
    return True


def _selected_scores(kif):
    """Return the scores of *kif* that feed the fit.

    Without --filter_name this is kif.scores itself. Otherwise it is the
    even-indexed scores when names[0] matches, followed by the odd-indexed
    scores when names[1] matches.
    """
    if not args.filter_name:
        return kif.scores
    picked = []
    for side in (0, 1):
        if ptn.search(kif.names[side]):
            picked.extend(kif.scores[side::2])
    return picked


for filepath in glob.glob(os.path.join(args.csa_dir, '**', '*.csa'), recursive=True):
    for kif in CSA.Parser.parse_file(filepath):
        if not _game_is_usable(kif):
            continue
        kif_scores = _selected_scores(kif)
        # With a name filter, skip games in which no player matched.
        if args.filter_name and len(kif_scores) == 0:
            continue
        scores.extend(kif_scores)
        # 0.5 for a draw, otherwise 2 - kif.win. This is presumably 1 for a
        # black win and 0 for a white win, assuming cshogi's
        # BLACK_WIN == 1 and WHITE_WIN == 2 — confirm.
        outcome = 0.5 if kif.win == DRAW else 2 - kif.win
        results.extend([outcome] * len(kif_scores))
df = pd.DataFrame({'score': scores, 'result': results})
# Exclude positions with no evaluation (score 0) and mate positions
# (|score| >= 30000).
df = df[(df['score'].abs() < 30000) & (df['score'] != 0)]
print(df['score'].describe())
print(df['result'].describe())


def _fit_scale(frame, p0=(600.0,)):
    """Fit the scale parameter of score_to_value to the rows of *frame*.

    *frame* must have 'score' (x) and 'result' (y) columns.
    Returns curve_fit's optimal-parameter array. Returns None when *frame*
    has fewer rows than there are parameters, because curve_fit rejects
    such input instead of fitting it.
    """
    if len(frame) < len(p0):
        return None
    popt, _pcov = curve_fit(score_to_value, frame['score'], frame['result'],
                            p0=list(p0))
    return popt


def _report_fit(frame):
    """Print the fitted parameters for *frame*, or a note if it cannot be fitted."""
    popt = _fit_scale(frame)
    if popt is None:
        print('not enough data to fit')
    else:
        print(popt)


_report_fit(df)
print('score < 1000')
_report_fit(df[df['score'].abs() < 1000])
print('score >= 1000')
_report_fit(df[df['score'].abs() >= 1000])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment