Created
October 23, 2020 17:19
-
-
Save masouduut94/fabd0f4a2e8cc000ffc0736bca6016d2 to your computer and use it in GitHub Desktop.
Wrap up the whole code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UctMctsAgent:
    """
    Basic no-frills implementation of an agent that performs UCT Monte Carlo
    tree search (MCTS) for Hex.

    Attributes:
        root_state (GameState): game state at the root of the search tree. The
            agent keeps its own deep copy of any state passed in.
        root (Node): root of the search tree.
        run_time (float): seconds spent in the most recent call to ``search``.
        node_count (int): number of nodes in the tree after the most recent search.
        num_rollouts (int): number of rollouts performed by the most recent search.
    """

    def __init__(self, state: 'GameState' = None):
        """
        Args:
            state: starting game state; it is deep-copied. When omitted, a fresh
                ``GameState(8)`` is created for this agent instead of a single
                instance being built once when the class is defined.
        """
        self.root_state = GameState(8) if state is None else deepcopy(state)
        self.root = Node()
        self.run_time = 0
        self.node_count = 0
        self.num_rollouts = 0

    def search(self, time_budget: int) -> None:
        """
        Run select / roll-out / backup iterations until ``time_budget`` seconds
        have elapsed, then record run time, tree size and rollout count.
        """
        # time.clock() was removed in Python 3.8; perf_counter is the
        # recommended clock for measuring elapsed time.
        from time import perf_counter

        start_time = perf_counter()
        num_rollouts = 0
        # Keep iterating until the time budget is exhausted.
        while perf_counter() - start_time < time_budget:
            node, state = self.select_node()
            # Player to move at the selected node; backup() uses it to decide
            # which side the rollout outcome rewards.
            turn = state.turn()
            outcome = self.roll_out(state)
            self.backup(node, turn, outcome)
            num_rollouts += 1
        self.run_time = perf_counter() - start_time
        self.node_count = self.tree_size()
        self.num_rollouts = num_rollouts

    def select_node(self) -> tuple:
        """
        Select a node in the tree to perform a single simulation from.

        Returns:
            tuple: ``(node, state)`` where ``state`` is a deep copy of
            ``root_state`` with every move on the path from the root to
            ``node`` already played.
        """
        node = self.root
        state = deepcopy(self.root_state)
        # Descend until we reach a node with no children.
        while node.children:
            # Move to the child with the highest ``value``, breaking ties at
            # random. (``value`` is defined on Node, presumably the UCT score.)
            children = list(node.children.values())
            max_value = max(child.value for child in children)
            node = choice([child for child in children if child.value == max_value])
            state.play(node.move)
            # Simulate from an unvisited child before expanding any deeper.
            if node.N == 0:
                return node, state
        # At a leaf: generate its children and return one of them at random.
        # If the leaf is terminal, expand() returns False and the leaf itself
        # is returned.
        if self.expand(node, state):
            node = choice(list(node.children.values()))
            state.play(node.move)
        return node, state

    @staticmethod
    def expand(parent: 'Node', state: 'GameState') -> bool:
        """
        Create one child of ``parent`` per move available in ``state`` and
        attach them to the tree.

        Returns:
            bool: False if the game has already been won in ``state`` (nothing
            to expand), True otherwise.
        """
        if state.winner != GameMeta.PLAYERS['none']:
            # Game is over at this node, so there is nothing to expand.
            return False
        parent.add_children([Node(move, parent) for move in state.moves()])
        return True

    @staticmethod
    def roll_out(state: 'GameState') -> int:
        """
        Play uniformly random moves from ``state`` until there is a winner.

        ``state`` is modified in place.

        Args:
            state: game state to simulate from.

        Returns:
            int: the winner of the simulated game.
        """
        # Work on a copy so we never mutate a list owned by ``state``.
        moves = list(state.moves())
        while state.winner == GameMeta.PLAYERS['none']:
            move = choice(moves)
            state.play(move)
            moves.remove(move)
        return state.winner

    @staticmethod
    def backup(node: 'Node', turn: int, outcome: int) -> None:
        """
        Propagate a rollout result from ``node`` up to the root, updating
        visit counts ``N`` and accumulated rewards ``Q``.

        Args:
            node: node the rollout was started from.
            turn: player to move at ``node`` when the rollout started.
            outcome: winner of the rollout.
        """
        # The reward is credited to the player who made the move leading into
        # each node, i.e. NOT the player to move at that node. So if the
        # player to move at ``node`` won, ``node`` itself gets 0.
        reward = 0 if outcome == turn else 1
        while node is not None:
            node.N += 1
            node.Q += reward
            node = node.parent
            # Players alternate at each level, so the reward flips.
            reward = 1 - reward

    def best_move(self) -> tuple:
        """
        Return the best move according to the current tree.

        Returns:
            The move of the most-visited root child (ties broken at random),
            or ``GameMeta.GAME_OVER`` if the game at the root is already won.

        Raises:
            ValueError: if the root has no children yet (e.g. ``search`` has
                not completed any iteration).
        """
        if self.root_state.winner != GameMeta.PLAYERS['none']:
            return GameMeta.GAME_OVER
        children = list(self.root.children.values())
        if not children:
            raise ValueError('best_move() called before search() expanded the root')
        max_visits = max(child.N for child in children)
        return choice([child for child in children if child.N == max_visits]).move

    def move(self, move: tuple) -> None:
        """
        Play ``move`` at the root and update the tree accordingly.

        This lets a player choose an action manually (which might not be the
        best one). If ``move`` is a known child of the root, its subtree is
        kept as the new root; otherwise the tree is discarded.

        Args:
            move: the move to play.
        """
        if move in self.root.children:
            child = self.root.children[move]
            child.parent = None
            self.root = child
            self.root_state.play(child.move)
            return
        # The move is not among the root's children, so just throw away the
        # tree and start over.
        self.root_state.play(move)
        self.root = Node()

    def set_gamestate(self, state: 'GameState') -> None:
        """
        Replace ``root_state`` with a deep copy of ``state`` and clear the tree,
        since none of the stored statistics apply to the new state.
        """
        self.root_state = deepcopy(state)
        self.root = Node()

    def statistics(self) -> tuple:
        """Return ``(num_rollouts, node_count, run_time)`` from the last search."""
        return self.num_rollouts, self.node_count, self.run_time

    def tree_size(self) -> int:
        """Count the nodes in the tree, including the root, by iterative traversal."""
        count = 0
        pending = [self.root]
        while pending:
            node = pending.pop()
            count += 1
            pending.extend(node.children.values())
        return count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment