Created
June 12, 2021 20:58
-
-
Save Hamptonjc/eaddb79518c3b25c08fbe055b6a3c2e3 to your computer and use it in GitHub Desktop.
Monte-Carlo-Tree-Search-TicTacToe.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Monte-Carlo-Tree-Search-TicTacToe.ipynb", | |
"provenance": [], | |
"authorship_tag": "ABX9TyMbIykFOYkaFsvrPxjsvo9k", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Hamptonjc/eaddb79518c3b25c08fbe055b6a3c2e3/monte-carlo-tree-search-tictactoe.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fPx3NuaD__b4" | |
}, | |
"source": [ | |
"# Monte Carlo Tree Search With Tic-Tac-Toe" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "REy5l3Z6_7s8" | |
}, | |
"source": [ | |
"# Imports\n", | |
"import numpy as np\n", | |
"from typing import Tuple, List, Union\n", | |
"import math" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "PseP1ZFpAKHj" | |
}, | |
"source": [ | |
"## Game" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6E8_o3ZPAJBQ" | |
}, | |
"source": [ | |
"class TicTacToe:\n", | |
" \n", | |
" def __init__(self) -> None:\n", | |
" self.new_game()\n", | |
" \n", | |
" def new_game(self) -> None:\n", | |
" self.board = np.ones((3,3)) * -1\n", | |
" self.player_turn = 1\n", | |
" print(\"New game!\\nCurrent board:\\n\", self.board, \"\\n\")\n", | |
" \n", | |
" \n", | |
" def make_move(self, coord: Tuple[int, int]):\n", | |
" x, y = coord\n", | |
" if self.board[x,y] == 0 or self.board[x,y] == 1:\n", | |
" print(\"That space is already occupied! Try again...\")\n", | |
" return\n", | |
" self.board[x,y] = 1\n", | |
" if self.check_for_win(1):\n", | |
" print(\"You win :-)\\n\", self.board, \"\\n\")\n", | |
" return\n", | |
" \n", | |
" if self.check_for_draw():\n", | |
" print(\"Its a draw!\\n\", self.board, \"\\n\")\n", | |
" return\n", | |
" \n", | |
" print(\"computer's turn...\", \"\\n\")\n", | |
" self.player_turn = 0\n", | |
" self.computer_move(pawn=0, random_play=False)\n", | |
"\n", | |
" \n", | |
" def computer_move(self, pawn: int, random_play: bool=True,\n", | |
" verbose: bool=True, check_for_result: bool=True) -> None:\n", | |
" if random_play:\n", | |
" available_spaces = np.argwhere(self.board == -1)\n", | |
" random_idx = np.random.choice(len(available_spaces), size=1, replace=False)\n", | |
" x, y = available_spaces[random_idx, :][0]\n", | |
" self.board[x,y] = pawn\n", | |
" else:\n", | |
" mcts = MCTS(self)\n", | |
" action = mcts.iterate(500000)\n", | |
" x, y = action\n", | |
" self.board[x,y] = pawn\n", | |
" \n", | |
" self.player_turn = 1 if pawn == 0 else 0\n", | |
" \n", | |
" if check_for_result:\n", | |
" if self.check_for_win(pawn):\n", | |
" if verbose:\n", | |
" print(\"Computer wins :-(\\n\", self.board, \"\\n\")\n", | |
" return\n", | |
" if self.check_for_draw():\n", | |
" if verbose:\n", | |
" print(\"Its a draw!\\n\", self.board, \"\\n\")\n", | |
" return \n", | |
" if verbose: \n", | |
" print(f\"Computer played ({x},{y})\", \"\\n\")\n", | |
" print(\"Current board:\\n\", self.board, \"\\n\")\n", | |
" \n", | |
" \n", | |
" def check_for_win(self, player: int) -> bool:\n", | |
" # horizontal win\n", | |
" for row in self.board:\n", | |
" if all(i == player for i in row):\n", | |
" return True\n", | |
" # vertical win\n", | |
" for col in self.board.T:\n", | |
" if all(i == player for i in col):\n", | |
" return True\n", | |
" # Diagonal win\n", | |
" # l2r\n", | |
" if all(i == player for i in self.board.diagonal()):\n", | |
" return True\n", | |
" #r2l\n", | |
" if all(i == player for i in np.fliplr(self.board).diagonal()):\n", | |
" return True\n", | |
" # Else No win\n", | |
" return False\n", | |
" \n", | |
" \n", | |
" def check_for_draw(self) -> bool:\n", | |
" return all(i > -1 for i in self.board.flatten())\n", | |
" \n", | |
" \n", | |
" def simulate_play(self, board_state: np.array, players_turn: int, verbose: bool=True) -> int:\n", | |
" saved_board = self.board\n", | |
" self.board = board_state.copy()\n", | |
" p0_win_value = 10\n", | |
" p1_win_value = -10\n", | |
" draw_value = 0\n", | |
" \n", | |
" # check board state for a win or draw\n", | |
" if self.check_for_win(players_turn):\n", | |
" self.board = saved_board\n", | |
" if players_turn == 1:\n", | |
" return p1_win_value\n", | |
" else:\n", | |
" return p0_win_value\n", | |
" elif self.check_for_win((0 if players_turn==1 else 1)):\n", | |
" self.board = saved_board\n", | |
" if players_turn == 1:\n", | |
" return p0_win_value\n", | |
" else:\n", | |
" return p1_win_value\n", | |
" elif self.check_for_draw():\n", | |
" self.board = saved_board\n", | |
" return draw_value\n", | |
" \n", | |
" # Simulate\n", | |
" while True:\n", | |
" self.computer_move(players_turn, verbose=False, check_for_result=False)\n", | |
" \n", | |
" if self.check_for_win(players_turn):\n", | |
" if verbose:\n", | |
" print(\"p1 wins:\\n\", self.board)\n", | |
" self.board = saved_board\n", | |
" if players_turn == 1:\n", | |
" return p1_win_value\n", | |
" else:\n", | |
" return p0_win_value\n", | |
" \n", | |
" elif self.check_for_draw():\n", | |
" if verbose:\n", | |
" print(\"draw:\\n\", self.board)\n", | |
" self.board = saved_board\n", | |
" return draw_value\n", | |
" \n", | |
" self.computer_move((0 if players_turn==1 else 1), verbose=False, check_for_result=False)\n", | |
" \n", | |
" if self.check_for_win((0 if players_turn==1 else 1)):\n", | |
" if verbose:\n", | |
" print(\"p0 wins:\\n\", self.board)\n", | |
" self.board = saved_board\n", | |
" if players_turn == 1:\n", | |
" return p0_win_value\n", | |
" else:\n", | |
" return p1_win_value\n", | |
" \n", | |
" elif self.check_for_draw():\n", | |
" if verbose:\n", | |
" print(\"draw:\\n\", self.board)\n", | |
" self.board = saved_board\n", | |
" return draw_value\n", | |
" " | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fvuwY4kPARtS" | |
}, | |
"source": [ | |
"## Monte Carlo Tree Search" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "t0P8rVrQAQ3D" | |
}, | |
"source": [ | |
"\n", | |
"class Node:\n", | |
" def __init__(self, state: np.array, ucb: float, value: int, visits: int,\n", | |
" leaves: List['Node'], parent: Union['Node', None], player_turn: int) -> None:\n", | |
" self.state = state\n", | |
" self.value = value\n", | |
" self.visits = visits\n", | |
" self.leaves = leaves\n", | |
" self.parent = parent\n", | |
" self.player_turn = player_turn\n", | |
"\n", | |
"\n", | |
"class MCTS:\n", | |
"\n", | |
" def __init__(self, game: TicTacToe) -> None:\n", | |
" self.game = game\n", | |
" \n", | |
" \n", | |
" def rollout(self, state: np.array, players_turn: int) -> int:\n", | |
" return self.game.simulate_play(state.copy(), players_turn, verbose=False)\n", | |
" \n", | |
" def iterate(self, n_iterations: int) -> Tuple[int, int]:\n", | |
" # New MCTS session\n", | |
" self.tree = Node(state=self.game.board, ucb=0, value=0, visits=1,\n", | |
" leaves=[], parent=None, player_turn=self.game.player_turn)\n", | |
"\n", | |
" # iterate for a set number of iterations\n", | |
" for _ in range(n_iterations):\n", | |
" # assign current node to the root\n", | |
" current_node = self.tree\n", | |
" # if current node is not a leaf...\n", | |
" while len(current_node.leaves) != 0:\n", | |
" # select an action based on UCB\n", | |
" current_node = self.selection(current_node)\n", | |
" # if this is the first visit to the current node...\n", | |
" if current_node.visits == 0:\n", | |
" # perform rollout\n", | |
" value = self.rollout(current_node.state, current_node.player_turn)\n", | |
" # backpropogate the results\n", | |
" self.backprop(value, current_node)\n", | |
" # else if this isn't the first visit to the current node...\n", | |
" else:\n", | |
" # Expand the current node\n", | |
" self.expansion(current_node)\n", | |
" # update the current node to the first new child node\n", | |
" try:\n", | |
" current_node = current_node.leaves[0]\n", | |
" except:\n", | |
" pass\n", | |
" # Perform rollout from new node\n", | |
" value = self.rollout(current_node.state, current_node.player_turn)\n", | |
" # backpropogate the results\n", | |
" self.backprop(value, current_node)\n", | |
" # Find the best next state from root\n", | |
" best_state = self.tree.leaves[0]\n", | |
" best_state_value = best_state.value\n", | |
" for state in self.tree.leaves[1:]:\n", | |
" state_value = state.value\n", | |
" if state_value > best_state_value:\n", | |
" best_state = state\n", | |
" best_state_value = state_value\n", | |
" # derive the action to get to the best state\n", | |
" for idx, (root, child) in enumerate(zip(self.tree.state.flatten(), best_state.state.flatten())):\n", | |
" if root != child:\n", | |
" action = np.unravel_index(idx, (3,3))\n", | |
" return action\n", | |
" \n", | |
" \n", | |
" def selection(self, current_node: Node) -> Node:\n", | |
" new_node = current_node.leaves[0]\n", | |
" new_node_ucb = self.calculate_UCB(new_node)\n", | |
" for n in current_node.leaves:\n", | |
" n_ucb = self.calculate_UCB(n)\n", | |
" if n_ucb > new_node_ucb:\n", | |
" new_node = n\n", | |
" new_node_ucb = n_ucb\n", | |
" return new_node\n", | |
" \n", | |
" \n", | |
" def expansion(self, current_node: Node) -> Node:\n", | |
" board = current_node.state\n", | |
" next_player = 0 if current_node.player_turn == 1 else 1\n", | |
" available_spaces = np.argwhere(board == -1)\n", | |
" for space in available_spaces:\n", | |
" x, y = space\n", | |
" state = board.copy()\n", | |
" state[x,y] = current_node.player_turn\n", | |
" state = state.reshape((3,3))\n", | |
" current_node.leaves.append(Node(state=state, ucb=np.inf, value=0,\n", | |
" visits=0, leaves=[], parent=current_node,\n", | |
" player_turn=next_player))\n", | |
" \n", | |
" def backprop(self, sim_result: int, current_node: Node):\n", | |
" current_node.value += sim_result\n", | |
" current_node.visits += 1\n", | |
" node = current_node.parent\n", | |
" while node is not None:\n", | |
" node.value += sim_result\n", | |
" node.visits += 1\n", | |
" node = node.parent\n", | |
" \n", | |
" \n", | |
" def calculate_UCB(self, node: Node) -> float:\n", | |
" value = node.value\n", | |
" if node.visits == 0:\n", | |
" return math.inf\n", | |
" else:\n", | |
" return value + 2*math.sqrt((math.log(node.parent.visits))/node.visits)\n", | |
" \n", | |
" def check_if_terminal(self, node: Node) -> bool:\n", | |
" if len(np.argwhere(node.state == -1)) == 0:\n", | |
" return True\n", | |
" else:\n", | |
" return False\n", | |
" \n" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "0IdMpN95AeGo" | |
}, | |
"source": [ | |
"## Play" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "J8bPTXmfAdE7", | |
"outputId": "8825399e-aee5-4fcb-b255-112c5b8080a4" | |
}, | |
"source": [ | |
"game = TicTacToe()" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"New game!\n", | |
"Current board:\n", | |
" [[-1. -1. -1.]\n", | |
" [-1. -1. -1.]\n", | |
" [-1. -1. -1.]] \n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Sp8qUHBkAZgi", | |
"outputId": "ba890e84-d00a-42fc-9263-d6ce741f0c5c" | |
}, | |
"source": [ | |
"game.make_move((2,2))" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"computer's turn... \n", | |
"\n", | |
"Computer played (1,1) \n", | |
"\n", | |
"Current board:\n", | |
" [[-1. -1. -1.]\n", | |
" [-1. 0. -1.]\n", | |
" [-1. -1. 1.]] \n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "aQ2clnG7Ak0i", | |
"outputId": "751eb840-6ff0-4405-b57b-08499579bb48" | |
}, | |
"source": [ | |
"game.make_move((2,1))" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"computer's turn... \n", | |
"\n", | |
"Computer played (2,0) \n", | |
"\n", | |
"Current board:\n", | |
" [[-1. -1. -1.]\n", | |
" [-1. 0. -1.]\n", | |
" [ 0. 1. 1.]] \n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "m0nRlgvaAyFF", | |
"outputId": "1176d0af-ae22-4f51-afbc-d4a8588d8f43" | |
}, | |
"source": [ | |
"game.make_move((1,2))" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"computer's turn... \n", | |
"\n", | |
"Computer wins :-(\n", | |
" [[-1. -1. 0.]\n", | |
" [-1. 0. 1.]\n", | |
" [ 0. 1. 1.]] \n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment