# Gist by @yuguorui, created January 4, 2018.
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger, DEBUG, StreamHandler, Formatter
from threading import Lock
import numpy as np
from chess_zero.config import Config
from chess_zero.env.chess_env import Winner
from chess_zero.env.chess_env import ChineseChessEnv
from chess_zero.agent import chinese_chess
logger = getLogger(__name__)
logger.setLevel(DEBUG)
# create console handler and set level to debug
ch = StreamHandler()
ch.setLevel(DEBUG)
formatter = Formatter('%(asctime)s - %(thread)d - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
# these are from AGZ nature paper
class VisitStats:
    def __init__(self):
        self.a = defaultdict(ActionStats)
        self.sum_n = 0
        self.p = None


class ActionStats:
    def __init__(self):
        self.n = 0
        self.w = 0
        self.q = 0
        self.p = 0  # prior; filled in when the node's policy is pushed to its edges
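# Field meanings, in the AGZ paper's notation: for an edge (s, a),
#   n -> N(s, a)  visit count
#   w -> W(s, a)  total action value
#   q -> Q(s, a) = W(s, a) / N(s, a)  mean action value
#   p -> P(s, a)  prior probability from the policy head
# VisitStats.sum_n tracks the sum over b of N(s, b) for the parent state s.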
class ChineseChessPlayer:
    # dot = False
    def __init__(self, config: Config, pipes=None, play_config=None, dummy=False):
        self.moves = []
        self.config = config
        self.play_config = play_config or self.config.play
        self.labels_n = config.n_labels
        self.labels = config.labels
        self.move_lookup = {
            chinese_chess.Move.from_ucci(move): i
            for move, i in zip(self.labels, range(self.labels_n))
        }
        if dummy:
            return

        self.pipe_pool = pipes
        self.node_lock = defaultdict(Lock)
    def reset_mcts(self):
        self.tree = defaultdict(VisitStats)
    def deboog(self, env):
        print(env.testeval())

        state = state_key(env)
        my_visit_stats = self.tree[state]
        stats = []
        for action, a_s in my_visit_stats.a.items():
            moi = self.move_lookup[action]
            stats.append(np.asarray([a_s.n, a_s.w, a_s.q, a_s.p, moi]))
        stats = np.asarray(stats)
        a = stats[stats[:, 0].argsort()[::-1]]

        for s in a:
            print(f'{self.labels[int(s[4])]:5}: '
                  f'n: {s[0]:3.0f} '
                  f'w: {s[1]:7.3f} '
                  f'q: {s[2]:7.3f} '
                  f'p: {s[3]:7.5f}')
    def action(self, env, can_stop=True) -> str:
        self.reset_mcts()
        # for tl in range(self.play_config.thinking_loop):
        root_value, naked_value = self.search_moves(env)
        policy = self.calc_policy(env)
        my_action = int(np.random.choice(range(self.labels_n),
                                         p=self.apply_temperature(policy, env.num_halfmoves)))

        if can_stop and self.play_config.resign_threshold is not None and \
                root_value <= self.play_config.resign_threshold \
                and env.num_halfmoves > self.play_config.min_resign_turn:
            # noinspection PyTypeChecker
            return None
        else:
            self.moves.append([env.observation, list(policy)])
            return self.config.labels[my_action]
    def search_moves(self, env) -> (float, float):
        futures = []
        with ThreadPoolExecutor(max_workers=self.play_config.search_threads) as executor:
            for _ in range(self.play_config.simulation_num_per_move):
                futures.append(executor.submit(self.search_my_move, env=env.copy(), is_root_node=True))

        vals = [f.result() for f in futures]
        # v = self.search_my_move(env.copy(), True)
        # vals = [v]
        return np.max(vals), vals[0]  # vals[0] is kind of racy
    def search_my_move(self, env: ChineseChessEnv, is_root_node=False) -> float:
        """
        Q and V are values from the point of view of this player (always white).
        P is from the point of view of the next player to move (black or white).
        :return: leaf value
        """
        if env.done:
            if env.winner == Winner.draw:
                return 0
            # assert env.whitewon != env.white_to_move  # side to move can't be winner!
            return -1

        state = state_key(env)

        with self.node_lock[state]:
            if state not in self.tree:
                leaf_p, leaf_v = self.expand_and_evaluate(env)
                self.tree[state].p = leaf_p
                return leaf_v  # I'm returning everything from the POV of side to move

            # SELECT STEP
            action_t = self.select_action_q_and_u(env, is_root_node)

            virtual_loss = self.play_config.virtual_loss

            my_visit_stats = self.tree[state]
            my_stats = my_visit_stats.a[action_t]

            my_visit_stats.sum_n += virtual_loss
            my_stats.n += virtual_loss
            my_stats.w += -virtual_loss
            my_stats.q = my_stats.w / my_stats.n

        env.step(action_t.ucci())
        leaf_v = self.search_my_move(env)  # next move from enemy POV
        leaf_v = -leaf_v

        # BACKUP STEP
        # on returning search path
        # update: N, W, Q
        with self.node_lock[state]:
            my_visit_stats.sum_n += -virtual_loss + 1
            my_stats.n += -virtual_loss + 1
            my_stats.w += virtual_loss + leaf_v
            my_stats.q = my_stats.w / my_stats.n

        return leaf_v
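    # A note on the virtual-loss bookkeeping above: during the SELECT step the
    # chosen edge is penalized with N += virtual_loss and W -= virtual_loss so
    # that concurrent search threads avoid piling onto the same path; during
    # BACKUP the same amount is removed again, so the net effect of one
    # simulation is N += 1, W += leaf_v, Q = W / N, exactly as in plain
    # single-threaded MCTS.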
    def expand_and_evaluate(self, env: ChineseChessEnv) -> (np.ndarray, float):
        """ expand new leaf, this is called only once per state
        this is called with state locked
        insert P(a|s), return leaf_v
        """
        state_planes = env.canonical_input_planes()

        leaf_p, leaf_v = self.predict(state_planes)
        # these are canonical policy and value (i.e. side to move is "white")

        if not env.white_to_move:
            leaf_p = Config.flip_policy(leaf_p)  # get it back to python-chess form

        return leaf_p, leaf_v
    def predict(self, state_planes):
        pipe = self.pipe_pool.pop()
        pipe.send(state_planes)
        ret = pipe.recv()
        self.pipe_pool.append(pipe)
        return ret
    # @profile
    def select_action_q_and_u(self, env, is_root_node) -> chinese_chess.Move:
        # this method is called with state locked
        state = state_key(env)

        my_visitstats = self.tree[state]

        if my_visitstats.p is not None:  # push p to edges
            tot_p = 1e-8
            for mov in env.board.legal_moves:
                logger.debug(f'Move: {mov}')
                # print(f'Move: {mov}')
                mov_p = my_visitstats.p[self.move_lookup[mov]]
                my_visitstats.a[mov].p = mov_p
                tot_p += mov_p
            for a_s in my_visitstats.a.values():
                a_s.p /= tot_p
            my_visitstats.p = None

        # numerator of the U(s, a) term
        xx_ = np.sqrt(my_visitstats.sum_n + 1)  # sqrt of sum(N(s, b); for all b)

        e = self.play_config.noise_eps
        c_puct = self.play_config.c_puct
        dir_alpha = self.play_config.dirichlet_alpha

        best_s = -999
        best_action = None

        # argmax(Q(s_t, a) + U(s_t, a))
        # score every action and pick the one with the highest value
        for action, a_s in my_visitstats.a.items():
            p_ = a_s.p
            if is_root_node:
                p_ = (1 - e) * p_ + e * np.random.dirichlet([dir_alpha])
            b = a_s.q + c_puct * p_ * xx_ / (1 + a_s.n)
            if b > best_s:
                best_s = b
                best_action = action

        assert best_action is not None
        return best_action
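    # The loop above implements the PUCT selection rule from AGZ:
    #
    #     score(a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b) + 1) / (1 + N(s, a))
    #
    # Worked example (made-up numbers): with Q = 0, P = 0.3, sum_n = 8, N = 2
    # and c_puct = 1.5, the score is 0 + 1.5 * 0.3 * sqrt(9) / 3 = 0.45.
    # At the root node the prior is additionally mixed with Dirichlet noise
    # weighted by noise_eps.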
    def apply_temperature(self, policy, turn):
        tau = np.power(self.play_config.tau_decay_rate, turn + 1)
        if tau < 0.1:
            tau = 0
        if tau == 0:
            action = np.argmax(policy)
            ret = np.zeros(self.labels_n)
            ret[action] = 1.0
            return ret
        else:
            ret = np.power(policy, 1 / tau)
            ret /= np.sum(ret)
            return ret
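    # Temperature reshaping: pi_i is proportional to policy_i ** (1 / tau), with
    # tau = tau_decay_rate ** (turn + 1); once tau drops below 0.1 it is treated
    # as 0 and the most-visited move is played deterministically.
    # Worked example (made-up numbers): policy [0.5, 0.3, 0.2] with tau = 0.5 is
    # squared to [0.25, 0.09, 0.04] and renormalized to about [0.658, 0.237, 0.105].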
    def calc_policy(self, env):
        """calc π(a|s0)
        :return:
        """
        state = state_key(env)
        my_visitstats = self.tree[state]
        policy = np.zeros(self.labels_n)
        for action, a_s in my_visitstats.a.items():
            policy[self.move_lookup[action]] = a_s.n

        policy /= np.sum(policy)
        return policy
    def sl_action(self, observation, my_action, weight=1):
        policy = np.zeros(self.labels_n)

        k = self.move_lookup[chinese_chess.Move.from_ucci(my_action)]
        policy[k] = weight

        self.moves.append([observation, list(policy)])
        return my_action
    def finish_game(self, z):
        """
        :param z: win=1, lose=-1, draw=0
        """
        for move in self.moves:  # add this game's result to all recorded moves
            move += [z]
def state_key(env: ChineseChessEnv) -> str:
    fen = env.board.fen().rsplit(' ', 1)  # drop the move clock
    return fen[0]
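# ------------------------------------------------------------------
# Hedged usage sketch (an illustration, not part of the original gist).
# It assumes the surrounding chess_zero project provides a default
# Config(), a ChineseChessEnv whose reset() returns the env itself, and
# a list of model pipes created elsewhere; `model_pipes` below is
# hypothetical and only illustrates the call shape.
#
#     config = Config()
#     env = ChineseChessEnv().reset()
#     player = ChineseChessPlayer(config, pipes=model_pipes)
#     while not env.done:
#         move = player.action(env)   # UCCI move string, or None to resign
#         if move is None:
#             break
#         env.step(move)
#     player.finish_game(z=1)  # +1 win, -1 loss, 0 draw, from this player's POV
# ------------------------------------------------------------------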