Last active
January 16, 2023 15:17
-
-
Save TadaoYamaoka/652835b25721ad2028c0d4a0db149cab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from cshogi import * | |
from cshogi.dlshogi import make_input_features, FEATURES1_NUM, FEATURES2_NUM | |
import numpy as np | |
import onnxruntime | |
import argparse | |
# Command line: ONNX model file followed by the position as an SFEN string.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('model', type=str, default='model', help='model file name')
arg_parser.add_argument('sfen', type=str, help='position')
args = arg_parser.parse_args()

# Execution providers are listed in the order they are handed to ONNX Runtime.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = onnxruntime.InferenceSession(args.model, providers=providers)

board = Board(sfen=args.sfen)

# One batch of 41 rows per model input; the later code fills rows 1.. with
# modified positions and compares them against row 0.
features1, features2 = (
    np.zeros((41, channels, 9, 9), dtype=np.float32)
    for channels in (FEATURES1_NUM, FEATURES2_NUM)
)

# NOTE(review): the whole batch arrays are passed here (not features1[0]);
# presumably the features of the unmodified position land in row 0 -- confirm
# against the cshogi make_input_features implementation.
make_input_features(board, features1, features2)
# Split the SFEN into its four space-separated fields.
board_src, turn, hand_src, ply = args.sfen.split(' ')

digits = list('123456789')
pos = []   # (file index, rank index, piece text) for every board piece, in scan order
i = 1      # next free batch row; row 0 is left for the unmodified position
rank = 0
file = 0

# Walk the board field once.  For every piece, build a copy of the board
# field with that piece replaced by an empty square (merging with adjacent
# empty-square counts) and write the features of that position into row i.
cursor = 0
board_len = len(board_src)
while cursor < board_len:
    symbol = board_src[cursor]

    if symbol == '/':
        # Rank separator: move to the next rank, restart the file counter.
        rank += 1
        file = 0
        cursor += 1
        continue

    if symbol in digits:
        # Empty-square counts are accounted for when the following piece is
        # handled (see digit_before below), so they are just skipped here.
        cursor += 1
        continue

    # A '+' prefix means the piece text is two characters long.
    piece_end = cursor + (2 if symbol == '+' else 1)

    digit_before = cursor > 0 and board_src[cursor - 1] in digits
    digit_after = piece_end < board_len and board_src[piece_end] in digits

    # The removed piece becomes one empty square, merged with any
    # empty-square count immediately to its left and/or right.
    gap = 1
    head_end = cursor
    tail_start = piece_end
    if digit_before:
        # Digit on the left: absorb it and advance the file past those squares.
        skipped = int(board_src[cursor - 1])
        gap += skipped
        file += skipped
        head_end = cursor - 1
    if digit_after:
        # Digit on the right: absorb it into the merged count.
        gap += int(board_src[piece_end])
        tail_start = piece_end + 1
    board_dst = board_src[:head_end] + str(gap) + board_src[tail_start:]

    board = Board(sfen=' '.join((board_dst, turn, hand_src, ply)))
    make_input_features(board, features1[i], features2[i])
    pos.append((file, rank, board_src[cursor:piece_end]))

    i += 1
    file += 1
    cursor = piece_end
# For every piece letter in the hand field, build a position whose hand lacks
# that entry (the letter together with its count prefix, i.e. all copies of
# that piece type for that side) and write its features into batch row i.
hand = []   # piece letters in the order their rows were written
if hand_src != '-':
    for j, symbol in enumerate(hand_src):
        if symbol.isdigit():
            # Count prefix; it is consumed together with the letter after it.
            # isdigit() also covers '0', which appears in counts such as "10".
            continue

        # Step back over the whole count prefix, which may be more than one
        # digit long (e.g. "18P"); otherwise a leftover digit would attach
        # itself to the next piece in the hand field.
        left = j
        while left > 0 and hand_src[left - 1].isdigit():
            left -= 1

        hand_dst = hand_src[:left] + hand_src[j + 1:]
        if not hand_dst:
            # SFEN writes an empty hand as '-', not as an empty field.
            hand_dst = '-'

        board = Board(sfen=board_src + ' ' + turn + ' ' + hand_dst + ' ' + ply)
        make_input_features(board, features1[i], features2[i])
        hand.append(symbol)
        i += 1
# Run the whole batch through the model via an ONNX Runtime IO binding.
io_binding = session.io_binding()
for input_name, batch in (('input1', features1), ('input2', features2)):
    io_binding.bind_cpu_input(input_name, batch)
for output_name in ('output_policy', 'output_value'):
    io_binding.bind_output(output_name)
session.run_with_iobinding(io_binding)

# Outputs come back in binding order; only the second one is used below.
policy_out, value_out = io_binding.copy_outputs_to_cpu()

# Difference between each row's value output and that of row 0.
importance = value_out - value_out[0]
# Print a 9x9 tab-separated grid with the importance of each board piece at
# its (rank, file) cell; squares without a piece stay empty.
grid = [[''] * 9 for _ in range(9)]
for row_index, (file, rank, _piece) in enumerate(pos, start=1):
    grid[rank][file] = '{:.5f}'.format(float(importance[row_index]))
print('\n'.join('\t'.join(cells) for cells in grid))

# Hand entries follow the board pieces in the batch, so their rows start
# right after the last board-piece row.
for row_index, piece in enumerate(hand, start=len(pos) + 1):
    print(piece, '{:.5f}'.format(float(importance[row_index])), sep='\t')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment