Skip to content

Instantly share code, notes, and snippets.

@TadaoYamaoka
Last active January 16, 2023 15:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TadaoYamaoka/652835b25721ad2028c0d4a0db149cab to your computer and use it in GitHub Desktop.
Save TadaoYamaoka/652835b25721ad2028c0d4a0db149cab to your computer and use it in GitHub Desktop.
from cshogi import *
from cshogi.dlshogi import make_input_features, FEATURES1_NUM, FEATURES2_NUM
import numpy as np
import onnxruntime
import argparse
# Command-line interface: the ONNX model path followed by the SFEN string of
# the position to analyse.
parser = argparse.ArgumentParser()
# NOTE(review): 'model' is a required positional, so argparse presumably never
# applies this default -- confirm whether nargs='?' was intended.
parser.add_argument('model', type=str, default='model', help='model file name')
parser.add_argument('sfen', type=str, help='position')
args = parser.parse_args()

# Providers are listed in priority order: CUDA first, CPU second.
session = onnxruntime.InferenceSession(
    args.model,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)

# Input batch of 41 rows. The rest of the script uses row 0 for the
# unmodified position and rows 1.. for positions with one piece removed.
batch_size = 41
features1 = np.zeros((batch_size, FEATURES1_NUM, 9, 9), dtype=np.float32)
features2 = np.zeros((batch_size, FEATURES2_NUM, 9, 9), dtype=np.float32)

# Encode the original position. The whole arrays are passed here;
# presumably make_input_features fills the first row -- TODO confirm.
board = Board(sfen=args.sfen)
make_input_features(board, features1, features2)
# Split the SFEN into its four space-separated fields.
board_src, turn, hand_src, ply = args.sfen.split(' ')

# pos collects (file, rank, piece) for every board piece, in batch-row order.
pos = []
# i is the next free batch row; row 0 holds the unmodified position.
i = 1
digits = [str(n) for n in range(1, 10)]

# Walk the board field once. For each piece, build a copy of the board string
# with that piece replaced by an empty square, merging the new empty square
# with any run-length digit directly before and/or after it.
row = 0
col = 0
idx = 0
board_len = len(board_src)
while idx < board_len:
    ch = board_src[idx]
    if ch == '/':
        # Rank separator: move to the next rank, restart the file counter.
        row += 1
        col = 0
        idx += 1
        continue
    if ch in digits:
        # Empty-square counts are accounted for when the next piece is seen.
        idx += 1
        continue

    # A '+' prefix marks a two-character (promoted) piece token.
    end = idx + 2 if ch == '+' else idx + 1

    digit_before = idx > 0 and board_src[idx - 1] in digits
    digit_after = end < board_len and board_src[end] in digits
    empty_before = int(board_src[idx - 1]) if digit_before else 0
    empty_after = int(board_src[end]) if digit_after else 0

    # Skip over the empty squares that precede this piece on its rank.
    col += empty_before

    # Replace "<digit?><piece><digit?>" with a single merged empty count.
    cut_from = idx - 1 if digit_before else idx
    cut_to = end + 1 if digit_after else end
    board_dst = (
        board_src[:cut_from]
        + str(empty_before + empty_after + 1)
        + board_src[cut_to:]
    )

    board = Board(sfen=board_dst + ' ' + turn + ' ' + hand_src + ' ' + ply)
    make_input_features(board, features1[i], features2[i])
    pos.append((col, row, board_src[idx:end]))
    i += 1

    # Step past the piece token and onto the next file.
    col += 1
    idx = end
# For every piece type held in hand, encode the position with that whole
# group (optional count + piece letter) removed from the hand field.
# hand collects the piece letters in batch-row order, continuing from row i.
hand = []
if hand_src != '-':
    k = 0
    while k < len(hand_src):
        # A group is an optional decimal count followed by one piece letter.
        # The count may have more than one digit (e.g. "10P", "18p").
        group_start = k
        while k < len(hand_src) and hand_src[k].isdigit():
            k += 1
        if k >= len(hand_src):
            # Trailing digits without a piece letter: nothing left to remove.
            break
        piece = hand_src[k]
        k += 1

        # Removing the only group must yield '-', not an empty field;
        # an empty field would make the SFEN malformed.
        hand_dst = (hand_src[:group_start] + hand_src[k:]) or '-'

        board = Board(sfen=board_src + ' ' + turn + ' ' + hand_dst + ' ' + ply)
        make_input_features(board, features1[i], features2[i])
        hand.append(piece)
        i += 1
# Run the whole batch through the model using I/O binding.
io_binding = session.io_binding()
io_binding.bind_cpu_input('input1', features1)
io_binding.bind_cpu_input('input2', features2)
io_binding.bind_output('output_policy')
io_binding.bind_output('output_value')
session.run_with_iobinding(io_binding)
# Outputs come back in binding order: policy first, value second.
y1, y2 = io_binding.copy_outputs_to_cpu()

# Importance of a piece = value with the piece removed minus the value of the
# unmodified position (row 0). Flatten to one scalar per batch row so that
# float() below never receives a 1-element array (deprecated since NumPy
# 1.25); reshape raises if the model returns more than one value per row.
# NOTE(review): assumes the value output is (batch,) or (batch, 1) -- confirm.
importance = (y2 - y2[0]).reshape(y2.shape[0])

# 9x9 text grid indexed [rank][file]; squares without a piece stay empty.
output = [['' for _ in range(9)] for _ in range(9)]
for (file, rank, _piece), delta in zip(pos, importance[1:]):
    output[rank][file] = format(float(delta), '.5f')
print('\n'.join(['\t'.join(row) for row in output]))

# Hand pieces occupy the batch rows immediately after the board pieces.
for piece, delta in zip(hand, importance[len(pos) + 1:]):
    print(piece, format(float(delta), '.5f'), sep='\t')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment