class UctMctsAgent:
    """
    Basic no-frills implementation of an agent that performs MCTS for Hex.
    Attributes:
        root_state (GameState): Game simulator that helps us understand the game situation.
        root (Node): Root of the search tree.
        run_time (int): Time allotted to each run.
        node_count (int): Total number of nodes in the tree.
        num_rollouts (int): Number of rollouts performed in each search.
    def search(self, time_budget: int) -> None:
        """
        Search and update the search tree for a
        specified amount of time in seconds.
        """
        start_time = perf_counter()  # needs `from time import perf_counter`; time.clock was removed in Python 3.8
        num_rollouts = 0
        # do until we exceed our time budget
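The loop in `search` keeps running simulations until the time budget is spent. Below is a minimal sketch of that budgeting pattern under stated assumptions: `run_for_budget` is an illustrative name, and the hypothetical `simulate_once` callable stands in for one select, roll-out and back-up cycle of the agent.

```python
from time import perf_counter
from typing import Callable


def run_for_budget(time_budget: float, simulate_once: Callable[[], None]) -> int:
    """Call `simulate_once` repeatedly until `time_budget` seconds have passed.

    Returns the number of completed simulations (rollouts).
    """
    start_time = perf_counter()  # high-resolution clock meant for measuring intervals
    num_rollouts = 0
    # do until we exceed our time budget
    while perf_counter() - start_time < time_budget:
        simulate_once()  # hypothetical: one select -> roll out -> back up cycle
        num_rollouts += 1
    return num_rollouts
```

Because the budget is checked before each simulation, the last simulation can overrun the budget slightly; that is usually acceptable when a single rollout is short.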
    @staticmethod
    def backup(node: Node, turn: int, outcome: int) -> None:
        """
        Update the node statistics on the path from the passed node to the root to
        reflect the outcome of a randomly simulated playout.
        Args:
            node: Node from which to start propagating the result up to the root.
            turn: winner turn
            outcome: outcome of the rollout
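A minimal sketch of the backup step, assuming a hypothetical `StatNode` that carries only `parent`, `N` and `Q`. It further assumes that `turn` is the player to move at `node` (so the opponent made the move leading to it) and that a win is worth 1 and a loss 0, with the reward flipping at each level because the players alternate. The original agent may define the reward differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StatNode:
    """Hypothetical node with only the fields backup touches."""
    parent: Optional["StatNode"] = None
    N: int = 0  # visit count
    Q: int = 0  # accumulated reward


def backup(node: Optional[StatNode], turn: int, outcome: int) -> None:
    """Walk from `node` to the root, adding one visit and the reward at each level.

    Assumption: `turn` is the player to move at `node`, so the move into `node` was
    made by the opponent; that mover earns 1 if they won the rollout and 0 otherwise.
    """
    reward = 0 if outcome == turn else 1
    while node is not None:
        node.N += 1
        node.Q += reward
        node = node.parent
        reward = 1 - reward  # one level up, the other player made the move
```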
    @staticmethod
    def roll_out(state: GameState) -> int:
        """
        Simulate an entirely random game from the passed state and return the winning
        player.
        Args:
            state: game state
        Returns:
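One way to implement a fully random Hex playout is sketched below. It assumes a plain list-of-lists board using the `GameMeta.PLAYERS` values (0 empty, 1 white, 2 black) and the offsets from `GameMeta.NEIGHBOR_PATTERNS`; the edge assignment (white links top and bottom, black links left and right) and the helper names `connects` and `roll_out` are assumptions. Instead of testing for a winner after every move, it fills the whole board and checks once: Hex has no draws and stones are never removed, so the winner of the filled board is the player who would have won the game.

```python
import random
from collections import deque

NEIGHBOR_PATTERNS = ((-1, 0), (0, -1), (-1, 1), (0, 1), (1, 0), (1, -1))
EMPTY, WHITE, BLACK = 0, 1, 2  # same values as GameMeta.PLAYERS


def connects(board, player):
    """Breadth-first search for a chain of `player` stones between that player's edges.

    Assumed edges: white links row 0 to the last row, black links column 0 to the last column.
    """
    n = len(board)
    if player == WHITE:
        starts = [(0, c) for c in range(n) if board[0][c] == player]
    else:
        starts = [(r, 0) for r in range(n) if board[r][0] == player]
    seen = set(starts)
    frontier = deque(starts)
    while frontier:
        r, c = frontier.popleft()
        if (player == WHITE and r == n - 1) or (player == BLACK and c == n - 1):
            return True
        for dr, dc in NEIGHBOR_PATTERNS:
            nr, nc = r + dr, c + dc
            if 0 <= nr < n and 0 <= nc < n and (nr, nc) not in seen and board[nr][nc] == player:
                seen.add((nr, nc))
                frontier.append((nr, nc))
    return False


def roll_out(board, to_move, rng=random):
    """Fill every empty cell in random order, alternating players, and return the winner."""
    board = [row[:] for row in board]  # work on a copy; the caller's board is untouched
    empty = [(r, c) for r, row in enumerate(board) for c, v in enumerate(row) if v == EMPTY]
    rng.shuffle(empty)
    for r, c in empty:
        board[r][c] = to_move
        to_move = BLACK if to_move == WHITE else WHITE
    # On a completely filled board exactly one player has a winning chain.
    return WHITE if connects(board, WHITE) else BLACK
```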
    def select_node(self) -> tuple:
        """
        Select a node in the tree to perform a single simulation from.
        """
        node = self.root
        state = deepcopy(self.root_state)
        # stop when we reach a leaf node
        while len(node.children) != 0:
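`select_node` descends the tree until it reaches a leaf. A common rule for picking which child to follow is UCB1 (the "UCT" in the class name), which scores a child by its mean reward plus an exploration bonus, `Q / N + c * sqrt(ln(N_parent) / N)`. The sketch below assumes a hypothetical `UctNode` and an exploration constant `c = 0.5`; it leaves out applying each chosen move to the copied `state` and expanding the leaf, which the real method also has to do.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional

INF = float("inf")  # same role as GameMeta.INF


@dataclass(eq=False)  # compare nodes by identity, not by field values
class UctNode:
    """Hypothetical node holding only what UCT selection needs."""
    parent: Optional["UctNode"] = field(default=None, repr=False)
    N: int = 0      # visit count
    Q: float = 0.0  # reward accumulated over all visits
    children: List["UctNode"] = field(default_factory=list)

    def value(self, explore: float) -> float:
        """UCB1 score: mean reward plus an exploration bonus."""
        if self.N == 0:
            return INF  # try every unvisited child before revisiting any
        return self.Q / self.N + explore * math.sqrt(math.log(self.parent.N) / self.N)


def select_leaf(root: UctNode, explore: float = 0.5) -> UctNode:
    """Walk down from the root, always following the child with the highest UCB1 score."""
    node = root
    while node.children:
        node = max(node.children, key=lambda child: child.value(explore))
    return node
```

A larger `explore` value spreads visits more evenly across children; a smaller one concentrates them on children that have already scored well.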
class UctMctsAgent:
    """
    Basic no-frills implementation of an agent that performs MCTS for Hex.
    Attributes:
        root_state (GameState): Game simulator that helps us
            understand the game situation.
        root (Node): Root of the search tree.
        run_time (int): Time allotted to each run.
        node_count (int): Total number of nodes in the tree.
        num_rollouts (int): Number of rollouts performed in each search.
class Node:
    """
    Node for the MCTS. Stores the move applied to reach this node from its parent,
    stats for the associated game position, children, parent and outcome
    (the outcome stays 'none' unless the position ends the game).
    Args:
        move: Move applied to reach this node from its parent.
        parent: Parent of this node in the search tree.
    Attributes:
        N (int): Number of times this position was visited.
        Q (int): Average reward (wins - losses) from this position.
class GameMeta:
    PLAYERS = {'none': 0, 'white': 1, 'black': 2}
    INF = float('inf')
    GAME_OVER = -1
    EDGE1 = 1
    EDGE2 = 2
    NEIGHBOR_PATTERNS = ((-1, 0), (0, -1), (-1, 1), (0, 1), (1, 0), (1, -1))
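`NEIGHBOR_PATTERNS` lists the six offsets from a cell to its neighbors on the rhombus-shaped Hex board. A small sketch of turning those offsets into the on-board neighbors of a cell, assuming cells are `(row, col)` pairs on a `size` × `size` board; the function name `neighbors` is illustrative.

```python
from typing import List, Tuple

# Offsets copied from GameMeta.NEIGHBOR_PATTERNS.
NEIGHBOR_PATTERNS = ((-1, 0), (0, -1), (-1, 1), (0, 1), (1, 0), (1, -1))


def neighbors(cell: Tuple[int, int], size: int) -> List[Tuple[int, int]]:
    """Return the neighbors of `cell` that lie on a size x size Hex board."""
    row, col = cell
    result = []
    for dr, dc in NEIGHBOR_PATTERNS:
        r, c = row + dr, col + dc
        if 0 <= r < size and 0 <= c < size:  # drop offsets that fall off the board
            result.append((r, c))
    return result
```

An interior cell gets all six neighbors. The corners `(0, 0)` and `(size - 1, size - 1)` get only two, and the other two corners get three, which reflects the acute and obtuse corners of the rhombus.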
from numpy import zeros, int_
from unionfind import UnionFind
from meta import GameMeta


class GameState:
    """
    Stores information representing the current state of a game of Hex, namely
    the board and the current turn. Also provides functions for playing the game.
    """
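A minimal sketch of the board-and-turn bookkeeping the docstring describes, using the same numpy `zeros` and `int_` imports. `MiniHexState`, its `moves` and `play` methods, the 11 × 11 default size and white moving first are all assumptions; the real `GameState` presumably also uses the imported `UnionFind` to track connected groups and detect a winner, which this sketch omits.

```python
from numpy import argwhere, int_, zeros

EMPTY, WHITE, BLACK = 0, 1, 2  # same values as GameMeta.PLAYERS


class MiniHexState:
    """Hypothetical, stripped-down stand-in for GameState: a board plus the side to move."""

    def __init__(self, size: int = 11):
        self.size = size
        self.board = zeros((size, size), dtype=int_)  # every cell starts EMPTY (0)
        self.to_play = WHITE  # assumption: white moves first

    def moves(self):
        """Return all empty cells as (row, col) tuples, in row-major order."""
        return [tuple(int(i) for i in cell) for cell in argwhere(self.board == EMPTY)]

    def play(self, cell):
        """Place a stone for the side to move on an empty cell, then pass the turn."""
        if self.board[cell] != EMPTY:
            raise ValueError(f"cell {cell} is already occupied")
        self.board[cell] = self.to_play
        self.to_play = BLACK if self.to_play == WHITE else WHITE
```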
class UnionFind:
    """
    Notes:
        Union-find data structure specialized for finding Hex connections.
        Implementation inspired by UAlberta CMPUT 275 2015 class notes.
    Attributes:
        parent (dict): Parent of each element in a group.
        rank (dict): Rank of each group.
        groups (dict): Stores the groups and their chains of cells.
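A sketch of a union-find with path compression and union by rank that keeps `parent`, `rank` and `groups` dictionaries like the attribute list above. The class name `UnionFindSketch` and the method names `find`, `join` and `connected` are assumptions, not necessarily the original API. The usage at the end shows one common way such a structure detects a Hex win, assuming virtual nodes with the `GameMeta.EDGE1` and `GameMeta.EDGE2` values stand for a player's two board edges: the player has won once both edges are in the same group.

```python
class UnionFindSketch:
    """Union-find (disjoint sets) with union by rank and path compression."""

    def __init__(self):
        self.parent = {}  # element -> parent element (roots point to themselves)
        self.rank = {}    # root -> upper bound on the height of its tree
        self.groups = {}  # root -> list of all elements in its group

    def find(self, x):
        """Return the root of x's group, creating a singleton group for unseen x."""
        if x not in self.parent:
            self.parent[x] = x
            self.rank[x] = 0
            self.groups[x] = [x]
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: re-point every element on the path directly at the root.
        while x != root:
            next_x = self.parent[x]
            self.parent[x] = root
            x = next_x
        return root

    def join(self, a, b):
        """Merge the groups of a and b; return False if they were already joined."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra  # hang the shallower tree under the deeper one
        self.parent[rb] = ra
        self.groups[ra].extend(self.groups.pop(rb))
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1
        return True

    def connected(self, a, b):
        """True if a and b are in the same group."""
        return self.find(a) == self.find(b)


# Usage: virtual edge nodes for one player; cells are (row, col) tuples.
EDGE1, EDGE2 = 1, 2  # same values as GameMeta.EDGE1 / GameMeta.EDGE2
uf = UnionFindSketch()
for a, b in [(EDGE1, (0, 0)), ((0, 0), (1, 0)), ((1, 0), (2, 0)), ((2, 0), EDGE2)]:
    uf.join(a, b)
print(uf.connected(EDGE1, EDGE2))  # True
```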