Skip to content

Instantly share code, notes, and snippets.

@masouduut94
Created October 23, 2020 17:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masouduut94/fabd0f4a2e8cc000ffc0736bca6016d2 to your computer and use it in GitHub Desktop.
Save masouduut94/fabd0f4a2e8cc000ffc0736bca6016d2 to your computer and use it in GitHub Desktop.
Wrap up the whole code.
class UctMctsAgent:
"""
Basic no frills implementation of an agent that preforms MCTS for hex.
Attributes:
root_state (GameState): Game simulator that helps us to understand the game situation
root (Node): Root of the tree search
run_time (int): time per each run
node_count (int): the whole nodes in tree
num_rollouts (int): The number of rollouts for each search
"""
def __init__(self, state=GameState(8)):
self.root_state = deepcopy(state)
self.root = Node()
self.run_time = 0
self.node_count = 0
self.num_rollouts = 0
def search(self, time_budget: int) -> None:
"""
Search and update the search tree for a
specified amount of time in seconds.
"""
start_time = clock()
num_rollouts = 0
# do until we exceed our time budget
while clock() - start_time < time_budget:
node, state = self.select_node()
turn = state.turn()
outcome = self.roll_out(state)
self.backup(node, turn, outcome)
num_rollouts += 1
run_time = clock() - start_time
node_count = self.tree_size()
self.run_time = run_time
self.node_count = node_count
self.num_rollouts = num_rollouts
def select_node(self) -> tuple:
"""
Select a node in the tree to preform a single simulation from.
"""
node = self.root
state = deepcopy(self.root_state)
# stop if we find reach a leaf node
while len(node.children) != 0:
# descend to the maximum value node, break ties at random
children = node.children.values()
max_value = max(children, key=lambda n: n.value).value
max_nodes = [n for n in node.children.values()
if n.value == max_value]
node = choice(max_nodes)
state.play(node.move)
# if some child node has not been explored select it before expanding
# other children
if node.N == 0:
return node, state
# if we reach a leaf node generate its children and return one of them
# if the node is terminal, just return the terminal node
if self.expand(node, state):
node = choice(list(node.children.values()))
state.play(node.move)
return node, state
@staticmethod
def expand(parent: Node, state: GameState) -> bool:
"""
Generate the children of the passed "parent" node based on the available
moves in the passed gamestate and add them to the tree.
Returns:
bool: returns false If node is leaf (the game has ended).
"""
children = []
if state.winner != GameMeta.PLAYERS['none']:
# game is over at this node so nothing to expand
return False
for move in state.moves():
children.append(Node(move, parent))
parent.add_children(children)
return True
@staticmethod
def roll_out(state: GameState) -> int:
"""
Simulate an entirely random game from the passed state and return the winning
player.
Args:
state: game state
Returns:
int: winner of the game
"""
moves = state.moves() # Get a list of all possible moves in current state of the game
while state.winner == GameMeta.PLAYERS['none']:
move = choice(moves)
state.play(move)
moves.remove(move)
return state.winner
@staticmethod
def backup(node: Node, turn: int, outcome: int) -> None:
"""
Update the node statistics on the path from the passed node to root to reflect
the outcome of a randomly simulated playout.
Args:
node:
turn: winner turn
outcome: outcome of the rollout
Returns:
object:
"""
# Careful: The reward is calculated for player who just played
# at the node and not the next player to play
reward = 0 if outcome == turn else 1
while node is not None:
node.N += 1
node.Q += reward
node = node.parent
reward = 0 if reward == 1 else 1
def best_move(self) -> tuple:
"""
Return the best move according to the current tree.
Returns:
best move in terms of the most simulations number unless the game is over
"""
if self.root_state.winner != GameMeta.PLAYERS['none']:
return GameMeta.GAME_OVER
# choose the move of the most simulated node breaking ties randomly
max_value = max(self.root.children.values(), key=lambda n: n.N).N
max_nodes = [n for n in self.root.children.values() if n.N == max_value]
bestchild = choice(max_nodes)
return bestchild.move
def move(self, move: tuple) -> None:
"""
Make the passed move and update the tree appropriately. It is
designed to let the player choose an action manually (which might
not be the best action).
Args:
move:
"""
if move in self.root.children:
child = self.root.children[move]
child.parent = None
self.root = child
self.root_state.play(child.move)
return
# if for whatever reason the move is not in the children of
# the root just throw out the tree and start over
self.root_state.play(move)
self.root = Node()
def set_gamestate(self, state: GameState) -> None:
"""
Set the root_state of the tree to the passed gamestate, this clears all
the information stored in the tree since none of it applies to the new
state.
"""
self.root_state = deepcopy(state)
self.root = Node()
def statistics(self) -> tuple:
return self.num_rollouts, self.node_count, self.run_time
def tree_size(self) -> int:
"""
Count nodes in tree by BFS.
"""
Q = Queue()
count = 0
Q.put(self.root)
while not Q.empty():
node = Q.get()
count += 1
for child in node.children.values():
Q.put(child)
return count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment