# vloss (virtual loss) MCTS for a single-player game
# Author: Kyle Kastner
# License: BSD 3-Clause
# based on minigo implementation
# https://github.com/tensorflow/minigo/blob/master/mcts.py
# Useful discussion of the benefits
# http://www.moderndescartes.com/essays/agz/
# single player tweaks based on
# https://tmoer.github.io/AlphaZero/
# See survey
# http://mcts.ai/pubs/mcts-survey-master.pdf
# See similar implementation here
# https://github.com/junxiaosong/AlphaZero_Gomoku
# some changes from the high-level pseudo-code in the survey
import numpy as np
import copy
import collections
def softmax(x):
assert len(x.shape) == 1
probs = np.exp(x - np.max(x))
probs /= np.sum(probs)
return probs
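# Quick illustrative check of softmax (values are approximate):
# >>> softmax(np.array([1., 2., 3.]))
# array([0.09003057, 0.24472847, 0.66524096])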
class CountingManager(object):
def __init__(self, size=10, rollout_limit=1000):
self.size = size
self.random_state = np.random.RandomState(1999)
self.rollout_limit = rollout_limit
def get_next_state(self, state, action):
if action == state:
return state + 1
else:
return 0
def get_current_player(self, state):
return 0
def get_action_space(self):
return list(range(self.size))
def get_valid_actions(self, state):
return list(range(self.size))
def get_init_state(self):
return 0
def _rollout_fn(self, state):
return self.random_state.choice(self.get_valid_actions(state))
def rollout_from_state(self, state):
s = state
w, e = self.is_finished(s)
if e:
return 1.
c = 0
while True:
a = self._rollout_fn(s)
s = self.get_next_state(s, a)
w, e = self.is_finished(s)
c += 1
if e:
return 1. / float(c)
if c > self.rollout_limit:
return 0.
def is_finished(self, state):
        # returns (score, end): (1, True) if in a terminal state, (0, False) otherwise
if state == self.size - 1:
return 1, True
else:
return 0, False
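# Illustrative dynamics of the counting game: the agent must play
# action == current state to advance, so for size=5 the only winning action
# sequence is 0, 1, 2, 3, 4, and any wrong action resets the state to 0.
# >>> m = CountingManager(size=5)
# >>> m.get_next_state(0, 0)
# 1
# >>> m.get_next_state(1, 3)
# 0
# >>> m.is_finished(4)
# (1, True)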
class EmptyNode(object):
"""Empty node of MCTS tree, placeholder for root.
Code becomes simpler if all nodes have parents"""
def __init__(self):
self.parent = None
self.child_N = collections.defaultdict(float)
self.child_W = collections.defaultdict(float)
class MCTSNode(object):
"""Node of MCTS tree, can compute action scores of all children
    state_manager: a state manager instance which corresponds to this node
    parent: a parent MCTSNode (None means this is a "first action" node)
    action_to_here: action that led to this node, usually an integer, but this
        depends on the state_manager implementation (see get_action_space /
        get_valid_actions)
"""
def __init__(self, state_manager, parent=None, single_player=False, q_init=None, current_state=None, action_to_here=None, n_playout=1000, random_state=None):
# this assumes full observability?
if parent is None:
parent = EmptyNode()
action_to_here = None
current_state = state_manager.get_init_state()
self.single_player = single_player
self.q_init = q_init
if single_player and q_init is None:
raise ValueError("Single player nodes require q_init argument to be passed")
self.state_manager = state_manager
self.current_state = current_state
self.valid_actions = self.state_manager.get_valid_actions(self.current_state)
self.illegal_moves = np.array([0. if n in self.valid_actions else 1. for n in range(len(self.state_manager.get_action_space()))])
self.parent = parent
self.action_to_here = action_to_here
self.c_puct = 1.4
self.n_playout = n_playout
self.warn_at_ = 10000
if random_state is None:
raise ValueError("Must pass random_state object")
self.random_state = random_state
self.is_expanded = False
self.losses_applied = 0
# child_() allows vectorized computation of action score
self.child_N = np.zeros([len(self.state_manager.get_action_space())], dtype=np.float32)
self.child_W = np.zeros([len(self.state_manager.get_action_space())], dtype=np.float32)
# do we need child priors as in the minigo code
if q_init is None:
self.q_init = np.zeros([len(self.state_manager.get_action_space())], dtype=np.float32)
elif hasattr(q_init, "__len__"):
self.q_init = np.array(q_init).astype("float32")
else:
            # single float passed in, broadcast across the full action space
self.q_init = np.zeros([len(self.state_manager.get_action_space())], dtype=np.float32)
self.q_init += q_init
self.children = {} # move map to resulting node
def __repr__(self):
if self.single_player:
return "<MCTSNode move={}, N={}>".format(self.action_to_here, self.N)
else:
this_player = self.state_manager.get_current_player(self.current_state)
return "<MCTSNode move={}, N={}, to_play={}>".format(self.action_to_here, self.N, this_player)
def is_leaf(self):
return self.children == {}
@property
def child_Q(self):
child_N_nonzeros = np.where(self.child_N != 0.)[0]
        # start out_buffer with a copy of the q_init values, so that q_init
        # itself is not modified in place
        out_buffer = np.array(self.q_init)
        out_buffer[child_N_nonzeros] = self.child_W[child_N_nonzeros] / self.child_N[child_N_nonzeros]
return out_buffer
@property
def child_U(self):
return self.c_puct * np.sqrt(self.N) / (self.child_N + 1)
@property
def child_action_score(self):
return (self.child_Q + self.child_U)
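    # child_action_score is a PUCT-style tradeoff: Q(s, a) plus an exploration
    # bonus c_puct * sqrt(N(s)) / (1 + N(s, a)), where N(s) is this node's
    # visit count and N(s, a) is the per-child visit count; children with no
    # visits fall back to q_init for their Q term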
@property
def N(self):
return self.parent.child_N[self.action_to_here]
@N.setter
def N(self, value):
self.parent.child_N[self.action_to_here] = value
@property
def W(self):
return self.parent.child_W[self.action_to_here]
@W.setter
def W(self, value):
self.parent.child_W[self.action_to_here] = value
def get_best(self, random_tiebreak=True):
scores = self.child_action_score
valid_actions = [pos for pos in np.argsort(scores)[::-1] if pos in self.children.keys()]
valid_scores = scores[valid_actions]
if random_tiebreak:
# random tiebreaker
max_score = max(valid_scores)
assert len(valid_actions) == len(valid_scores)
equivalent_valid_scores = [(vs, va) for (vs, va) in zip(valid_scores, valid_actions) if vs == max_score]
            pair = self.random_state.choice(np.arange(len(equivalent_valid_scores)))
v_a = equivalent_valid_scores[pair][1]
child = self.children[v_a]
else:
v_a = valid_actions[0]
child = self.children[v_a]
return v_a, child
def add_virtual_loss(self, up_to):
"""propagate virtual loss upward (to root)
up_to, node to propagate until (track this)
"""
self.losses_applied += 1
if self.single_player:
# 1 player
# player == state related?
loss = 1
self.W += loss
if self.parent is None or self is up_to:
return
self.parent.add_virtual_loss(up_to)
else:
# 2 player
this_player = self.state_manager.get_current_player(self.current_state)
# use this to get alternating in 2 player
loss = -1 if this_player == 0 else 1
self.W += loss
if self.parent is None or self is up_to:
return
self.parent.add_virtual_loss(up_to)
def revert_virtual_loss(self, up_to):
"""undo virtual losses
up_to, node that was propagated until
"""
self.losses_applied -= 1
if self.single_player:
# 1 player
loss = 1
revert = -1 * loss
self.W += revert
if self.parent is None or self is up_to:
return
self.parent.revert_virtual_loss(up_to)
else:
# 2 player
this_player = self.state_manager.get_current_player(self.current_state)
# use this to get alternating in 2 player
loss = -1 if this_player == 0 else 1
revert = -1 * loss
self.W += revert
if self.parent is None or self is up_to:
return
self.parent.revert_virtual_loss(up_to)
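    # virtual loss lets a batch of leaf selections diverge: a pending leaf's
    # statistics are perturbed while it waits for evaluation, then
    # revert_virtual_loss restores them once the batch has been processed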
def multileaf_safe_backup_value(self, value, up_to):
# if for some reason we selected a leaf multiple times despite
# virtual loss, don't re-run it
if self.is_expanded:
return
self.is_expanded = True
self.backup_value(value, up_to=up_to)
def backup_value(self, value, up_to):
        if self.parent is not None:
if not (self is up_to):
if self.single_player:
self.parent.backup_value(value, up_to)
else:
self.parent.backup_value(-value, up_to)
self.N += 1
self.W += value
def maybe_add_children(self, actions_and_probs):
for elem in actions_and_probs:
self.maybe_add_child(elem)
def maybe_add_child(self, action_and_prob):
action = action_and_prob[0]
prob = action_and_prob[1]
if action not in self.children:
# need the state itself
state = self.current_state
next_state = self.state_manager.get_next_state(state, action)
seed = self.random_state.randint(0, 1E6)
rs = np.random.RandomState(seed)
if self.single_player:
# this will come from a neural network at some point
q_init = self.child_Q
else:
# assuming 2 player, 0 sum
q_init = 0.
self.children[action] = MCTSNode(self.state_manager, single_player=self.single_player, q_init=q_init, current_state=next_state, action_to_here=action, parent=self, random_state=rs)
return self.children[action]
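# Bookkeeping note: a node's own N and W live in its parent's child_N /
# child_W arrays (see the N / W properties above), which keeps per-child
# statistics in flat arrays so child_Q / child_U / child_action_score can be
# computed with vectorized numpy operations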
class MCTSPlayer(object):
def __init__(self, state_manager, single_player=False, n_readouts=100, random_state=None):
self.state_manager = state_manager
self.single_player = single_player
if self.single_player:
# will come from NN eventually, for now set simply
# first state has default 0 value...
q_init = 0. * np.array(self.state_manager.get_valid_actions(self.state_manager.get_init_state()))
else:
q_init = 0.
self.root = MCTSNode(state_manager, single_player=self.single_player, q_init=q_init, parent=None, random_state=random_state)
self.n_readouts = n_readouts
def select_leaf(self):
node = self.root
state = self.root.current_state
while True:
# if node has never been expanded, don't select a child
if node.is_leaf():
break
# this will need to do model evaluation
# can we cache the evaluations?
action, node = node.get_best()
state = node.state_manager.get_next_state(state, action)
return node
def tree_search(self):
# single leaf, non-parallel tree search
# useful sanity check / test case
node = self.select_leaf()
state = node.current_state
winner, end = node.state_manager.is_finished(state)
if not end:
actions = node.state_manager.get_valid_actions(state)
action_space = node.state_manager.get_action_space()
# uniform prior probs, zero out invalid actions
probs = np.zeros((len(action_space),))
probs[actions] = np.ones((len(actions))) / float(len(actions))
# only include valid actions to the child nodes!
actions_and_probs = list(zip(actions, probs[actions]))
node.maybe_add_children(actions_and_probs)
# random rollout
value = node.state_manager.rollout_from_state(state)
node.backup_value(value, up_to=self.root)
return None
def parallel_tree_search(self, parallel_readouts):
assert parallel_readouts > 0
leaves = []
failsafe = 0
while len(leaves) < parallel_readouts and failsafe < (2 * parallel_readouts):
failsafe += 1
# leaf selection will be NN based eventually
leaf = self.select_leaf()
winner, end = leaf.state_manager.is_finished(leaf.current_state)
if end:
# value will be predicted from NN
value = leaf.state_manager.rollout_from_state(leaf.current_state)
leaf.backup_value(value, up_to=self.root)
# go back to getting leaves
continue
leaf.add_virtual_loss(up_to=self.root)
leaves.append(leaf)
if leaves:
# get move probs and values
# in original code, network predicted
# move_probs, values = self.network.run_many(
# [leaf.position for leaf in leaves])
action_space = self.root.state_manager.get_action_space()
move_probs = [np.zeros((len(action_space),)) for l in leaves]
values = []
for n, leaf in enumerate(leaves):
# handle per-position invalid moves
actions = leaf.state_manager.get_valid_actions(leaf.current_state)
# uniform probs (probs not currently used)
move_probs[n][actions] = np.ones((len(actions))) / float(len(actions))
# value will be predicted from NN
value = leaf.state_manager.rollout_from_state(leaf.current_state)
# only include valid actions in the child nodes
actions_and_probs = list(zip(actions, move_probs[n][actions]))
leaf.maybe_add_children(actions_and_probs)
values.append(value)
for leaf, move_prob, value in zip(leaves, move_probs, values):
leaf.revert_virtual_loss(up_to=self.root)
leaf.multileaf_safe_backup_value(value, up_to=self.root)
return None
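    # the flow above mirrors batched network evaluation: select up to
    # parallel_readouts leaves (virtual loss keeps them distinct), evaluate
    # them together (here with random rollouts), then revert the virtual
    # losses and back up the values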
def simulate(self, parallel=0):
# not true parallelism yet, but batch evaluation should allow
# parallelism to be possible
assert parallel >= 0
current_readouts = self.root.N
while self.root.N < current_readouts + self.n_readouts:
if parallel == 0:
self.tree_search()
else:
self.parallel_tree_search(parallel_readouts=parallel)
return None
def pick_move(self):
# pick a move from the set of the best
# currently based purely on visit count
illegal_moves = np.where(self.root.illegal_moves == 1.)[0]
argsort_best = np.argsort(self.root.child_N)[::-1]
valid_actions = [a for a in argsort_best if a not in illegal_moves]
return valid_actions[0]
def play_move(self, move):
self.root = self.root.maybe_add_child((move, 1.))
return True
rollout_limit = 1000
# sizes larger than 5 become very difficult
size = 5
n_readouts = 1000
random_state = np.random.RandomState(1234)
sm = CountingManager(size=size, rollout_limit=rollout_limit)
mcts = MCTSPlayer(state_manager=sm, single_player=True, n_readouts=n_readouts, random_state=random_state)
# play the game out
while True:
print("Current state:", mcts.root.current_state)
print("")
mcts.simulate(parallel=5)
move = mcts.pick_move()
mcts.play_move(move)
score, finished = mcts.root.state_manager.is_finished(mcts.root.current_state)
if finished:
print(mcts.root.current_state)
print("Game over")
break
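# A purely serial run is also possible with the same classes; minimal sketch,
# guarded by a local flag (run_serial_demo is illustrative, not part of the
# original interface):
run_serial_demo = False
if run_serial_demo:
    serial_mcts = MCTSPlayer(state_manager=CountingManager(size=size, rollout_limit=rollout_limit),
                             single_player=True, n_readouts=n_readouts,
                             random_state=np.random.RandomState(4321))
    # parallel=0 expands one leaf per tree_search call instead of batching
    serial_mcts.simulate(parallel=0)
    print("Serial pick:", serial_mcts.pick_move())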