Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Hamptonjc/eaddb79518c3b25c08fbe055b6a3c2e3 to your computer and use it in GitHub Desktop.
Save Hamptonjc/eaddb79518c3b25c08fbe055b6a3c2e3 to your computer and use it in GitHub Desktop.
Monte-Carlo-Tree-Search-TicTacToe.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Monte-Carlo-Tree-Search-TicTacToe.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyMbIykFOYkaFsvrPxjsvo9k",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Hamptonjc/eaddb79518c3b25c08fbe055b6a3c2e3/monte-carlo-tree-search-tictactoe.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fPx3NuaD__b4"
},
"source": [
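"# Monte Carlo Tree Search With Tic-Tac-Toe\n",
"\n",
"Play Tic-Tac-Toe against a computer opponent that chooses its moves with Monte Carlo Tree Search (MCTS). You are player `1`, the computer is player `0`, and empty squares are shown as `-1`."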
]
},
{
"cell_type": "code",
"metadata": {
"id": "REy5l3Z6_7s8"
},
"source": [
"# Imports\n",
"import numpy as np\n",
"from typing import Tuple, List, Union\n",
"import math"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PseP1ZFpAKHj"
},
"source": [
"## Game"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6E8_o3ZPAJBQ"
},
"source": [
"class TicTacToe:\n",
"    \"\"\"Tic-Tac-Toe on a 3x3 numpy board: human (player 1) versus computer (player 0).\n",
"\n",
"    Empty squares hold -1. The computer chooses its replies with MCTS by default.\n",
"    \"\"\"\n",
"\n",
"    # Default number of MCTS iterations per computer move; lower it for faster play.\n",
"    MCTS_ITERATIONS = 500000\n",
"\n",
"    def __init__(self) -> None:\n",
"        self.new_game()\n",
"\n",
"    def new_game(self) -> None:\n",
"        \"\"\"Reset to an empty board with the human (player 1) to move, and print it.\"\"\"\n",
"        self.board = np.ones((3,3)) * -1\n",
"        self.player_turn = 1\n",
"        print(\"New game!\\nCurrent board:\\n\", self.board, \"\\n\")\n",
"\n",
"    def is_over(self) -> bool:\n",
"        \"\"\"Return True if either player has a full line or no empty square remains.\"\"\"\n",
"        return self.check_for_win(0) or self.check_for_win(1) or self.check_for_draw()\n",
"\n",
"    def make_move(self, coord: Tuple[int, int]) -> None:\n",
"        \"\"\"Place the human's mark (1) at ``coord`` = (row, col), then let the computer reply.\n",
"\n",
"        Invalid moves (game already over, square off the board or occupied) only\n",
"        print a message and leave the board unchanged.\n",
"        \"\"\"\n",
"        x, y = coord\n",
"        # Without this guard the human could keep playing after the game had ended.\n",
"        if self.is_over():\n",
"            print(\"The game is over! Call new_game() to play again.\")\n",
"            return\n",
"        # numpy silently wraps negative indices, so bounds must be checked explicitly.\n",
"        if not (0 <= x < 3 and 0 <= y < 3):\n",
"            print(\"That space is off the board! Try again...\")\n",
"            return\n",
"        if self.board[x,y] == 0 or self.board[x,y] == 1:\n",
"            print(\"That space is already occupied! Try again...\")\n",
"            return\n",
"        self.board[x,y] = 1\n",
"        if self.check_for_win(1):\n",
"            print(\"You win :-)\\n\", self.board, \"\\n\")\n",
"            return\n",
"\n",
"        if self.check_for_draw():\n",
"            print(\"Its a draw!\\n\", self.board, \"\\n\")\n",
"            return\n",
"\n",
"        print(\"computer's turn...\", \"\\n\")\n",
"        self.player_turn = 0\n",
"        self.computer_move(pawn=0, random_play=False)\n",
"\n",
"    def computer_move(self, pawn: int, random_play: bool=True,\n",
"                      verbose: bool=True, check_for_result: bool=True,\n",
"                      n_iterations: Union[int, None]=None) -> None:\n",
"        \"\"\"Place ``pawn`` on an empty square and pass the turn to the other player.\n",
"\n",
"        Args:\n",
"            pawn: mark to place (0 or 1).\n",
"            random_play: pick a uniformly random empty square if True, otherwise use MCTS.\n",
"            verbose: print the outcome / resulting board.\n",
"            check_for_result: check for a win or draw after moving.\n",
"            n_iterations: MCTS iterations (defaults to ``MCTS_ITERATIONS``); unused for random play.\n",
"        \"\"\"\n",
"        if random_play:\n",
"            available_spaces = np.argwhere(self.board == -1)\n",
"            random_idx = np.random.choice(len(available_spaces), size=1, replace=False)\n",
"            x, y = available_spaces[random_idx, :][0]\n",
"        else:\n",
"            iterations = self.MCTS_ITERATIONS if n_iterations is None else n_iterations\n",
"            x, y = MCTS(self).iterate(iterations)\n",
"        self.board[x,y] = pawn\n",
"\n",
"        self.player_turn = 1 if pawn == 0 else 0\n",
"\n",
"        if check_for_result:\n",
"            if self.check_for_win(pawn):\n",
"                if verbose:\n",
"                    print(\"Computer wins :-(\\n\", self.board, \"\\n\")\n",
"                return\n",
"            if self.check_for_draw():\n",
"                if verbose:\n",
"                    print(\"Its a draw!\\n\", self.board, \"\\n\")\n",
"                return\n",
"        if verbose:\n",
"            print(f\"Computer played ({x},{y})\", \"\\n\")\n",
"            print(\"Current board:\\n\", self.board, \"\\n\")\n",
"\n",
"    def check_for_win(self, player: int) -> bool:\n",
"        \"\"\"Return True if ``player`` owns a full row, column or diagonal.\"\"\"\n",
"        owned = self.board == player\n",
"        lines = [*owned, *owned.T, owned.diagonal(), np.fliplr(owned).diagonal()]\n",
"        return any(line.all() for line in lines)\n",
"\n",
"    def check_for_draw(self) -> bool:\n",
"        \"\"\"Return True when no empty square is left (callers check for a win first).\"\"\"\n",
"        return bool((self.board > -1).all())\n",
"\n",
"    def simulate_play(self, board_state: np.array, players_turn: int, verbose: bool=True) -> int:\n",
"        \"\"\"Play uniformly random moves from ``board_state`` until the game ends.\n",
"\n",
"        Args:\n",
"            board_state: starting position (not modified; play happens on a copy).\n",
"            players_turn: player (0 or 1) who moves first.\n",
"            verbose: print the final board and who won.\n",
"        Returns:\n",
"            10 if player 0 wins, -10 if player 1 wins, 0 for a draw.\n",
"        The live board and ``player_turn`` are restored afterwards, even on error.\n",
"        \"\"\"\n",
"        p0_win_value = 10\n",
"        p1_win_value = -10\n",
"        draw_value = 0\n",
"        opponent = 0 if players_turn == 1 else 1\n",
"\n",
"        def outcome(winner: int) -> int:\n",
"            return p1_win_value if winner == 1 else p0_win_value\n",
"\n",
"        saved_board, saved_turn = self.board, self.player_turn\n",
"        self.board = board_state.copy()\n",
"        try:\n",
"            # The starting position may already be decided.\n",
"            for player in (players_turn, opponent):\n",
"                if self.check_for_win(player):\n",
"                    return outcome(player)\n",
"            if self.check_for_draw():\n",
"                return draw_value\n",
"\n",
"            mover = players_turn\n",
"            while True:\n",
"                self.computer_move(mover, verbose=False, check_for_result=False)\n",
"                if self.check_for_win(mover):\n",
"                    if verbose:\n",
"                        print(f\"p{mover} wins:\\n\", self.board)\n",
"                    return outcome(mover)\n",
"                if self.check_for_draw():\n",
"                    if verbose:\n",
"                        print(\"draw:\\n\", self.board)\n",
"                    return draw_value\n",
"                mover = opponent if mover == players_turn else players_turn\n",
"        finally:\n",
"            # computer_move mutates both the board and the turn; undo so rollouts leave no trace.\n",
"            self.board = saved_board\n",
"            self.player_turn = saved_turn"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "fvuwY4kPARtS"
},
"source": [
"## Monte Carlo Tree Search"
]
},
{
"cell_type": "code",
"metadata": {
"id": "t0P8rVrQAQ3D"
},
"source": [
"\n",
"class Node:\n",
"    \"\"\"One board position in the MCTS tree together with its search statistics.\"\"\"\n",
"\n",
"    def __init__(self, state: np.array, ucb: float, value: int, visits: int,\n",
"                 leaves: List['Node'], parent: Union['Node', None], player_turn: int) -> None:\n",
"        self.state = state  # board position this node represents\n",
"        self.ucb = ucb  # score given at construction (previously accepted but dropped)\n",
"        self.value = value  # running sum of rollout values backed up through this node\n",
"        self.visits = visits  # visit count\n",
"        self.leaves = leaves  # child nodes; empty until the node is expanded\n",
"        self.parent = parent  # parent node, or None for the root\n",
"        self.player_turn = player_turn  # player (0 or 1) to move in `state`\n",
"\n",
"\n",
"class MCTS:\n",
"    \"\"\"Monte Carlo Tree Search that picks the computer's TicTacToe move.\n",
"\n",
"    Rollout values come from ``TicTacToe.simulate_play`` (+10 when player 0 wins,\n",
"    -10 when player 1 wins, 0 for a draw) and are summed into every node on the\n",
"    path from player 0's point of view. UCB selection flips the sign whenever\n",
"    player 1 is the one choosing, so the search assumes both sides play to win.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, game: TicTacToe, exploration: float = 2.0,\n",
"                 reward_scale: float = 10.0) -> None:\n",
"        self.game = game\n",
"        # Weight of the UCB1 exploration term.\n",
"        self.exploration = exploration\n",
"        # Magnitude of a win in simulate_play; scales mean values into [-1, 1].\n",
"        self.reward_scale = reward_scale\n",
"\n",
"    def rollout(self, state: np.array, players_turn: int) -> int:\n",
"        \"\"\"Randomly play out ``state`` with ``players_turn`` to move and return the outcome value.\"\"\"\n",
"        return self.game.simulate_play(state.copy(), players_turn, verbose=False)\n",
"\n",
"    def iterate(self, n_iterations: int) -> Tuple[int, int]:\n",
"        \"\"\"Run ``n_iterations`` search iterations from the game's position; return the chosen (row, col).\n",
"\n",
"        Raises:\n",
"            ValueError: if the current position is already won or full.\n",
"        \"\"\"\n",
"        # Copy so the tree never aliases the live game board.\n",
"        self.tree = Node(state=self.game.board.copy(), ucb=0, value=0, visits=1,\n",
"                         leaves=[], parent=None, player_turn=self.game.player_turn)\n",
"        if self.check_if_terminal(self.tree):\n",
"            raise ValueError(\"No legal moves: the game is already over.\")\n",
"\n",
"        for _ in range(n_iterations):\n",
"            # Selection: follow the best-UCB child down to a node without children.\n",
"            current_node = self.tree\n",
"            while current_node.leaves:\n",
"                current_node = self.selection(current_node)\n",
"            # Expansion: a node visited before (where the game is not over yet) gets\n",
"            # its children, and the iteration continues from the first one.\n",
"            if current_node.visits != 0 and not self.check_if_terminal(current_node):\n",
"                self.expansion(current_node)\n",
"                current_node = current_node.leaves[0]\n",
"            # Simulation and backpropagation.\n",
"            value = self.rollout(current_node.state, current_node.player_turn)\n",
"            self.backprop(value, current_node)\n",
"\n",
"        # Ensure there is a candidate move even when n_iterations == 0.\n",
"        if not self.tree.leaves:\n",
"            self.expansion(self.tree)\n",
"        # The most-visited child is the standard, low-variance final choice\n",
"        # (max keeps the first child on ties).\n",
"        best_child = max(self.tree.leaves, key=lambda child: child.visits)\n",
"        return self._action_between(self.tree.state, best_child.state)\n",
"\n",
"    @staticmethod\n",
"    def _action_between(before: np.array, after: np.array) -> Tuple[int, int]:\n",
"        \"\"\"Return (row, col) of the first square that differs between two boards.\"\"\"\n",
"        flat_idx = int(np.flatnonzero(before != after)[0])\n",
"        row, col = np.unravel_index(flat_idx, before.shape)\n",
"        return int(row), int(col)\n",
"\n",
"    def selection(self, current_node: Node) -> Node:\n",
"        \"\"\"Return the child with the highest UCB score (the first one on ties).\"\"\"\n",
"        return max(current_node.leaves, key=self.calculate_UCB)\n",
"\n",
"    def expansion(self, current_node: Node) -> None:\n",
"        \"\"\"Append one child per empty square, with the current player's mark placed there.\"\"\"\n",
"        board = current_node.state\n",
"        next_player = 0 if current_node.player_turn == 1 else 1\n",
"        for x, y in np.argwhere(board == -1):\n",
"            state = board.copy()\n",
"            state[x,y] = current_node.player_turn\n",
"            current_node.leaves.append(Node(state=state, ucb=np.inf, value=0,\n",
"                                            visits=0, leaves=[], parent=current_node,\n",
"                                            player_turn=next_player))\n",
"\n",
"    def backprop(self, sim_result: int, current_node: Node) -> None:\n",
"        \"\"\"Add ``sim_result`` and one visit to ``current_node`` and each of its ancestors.\"\"\"\n",
"        node = current_node\n",
"        while node is not None:\n",
"            node.value += sim_result\n",
"            node.visits += 1\n",
"            node = node.parent\n",
"\n",
"    def calculate_UCB(self, node: Node) -> float:\n",
"        \"\"\"UCB1 score of ``node`` for the player choosing it (the player to move at its parent).\n",
"\n",
"        Unvisited nodes score +inf so every child is tried once before any is revisited.\n",
"        \"\"\"\n",
"        if node.visits == 0:\n",
"            return math.inf\n",
"        # Mean value per visit (not the raw sum), scaled into [-1, 1].\n",
"        mean_value = node.value / (node.visits * self.reward_scale)\n",
"        # Stored values favour player 0; player 1 wants them as low as possible.\n",
"        if node.parent.player_turn == 1:\n",
"            mean_value = -mean_value\n",
"        return mean_value + self.exploration * math.sqrt(math.log(node.parent.visits) / node.visits)\n",
"\n",
"    @staticmethod\n",
"    def _has_line(state: np.array, player: int) -> bool:\n",
"        \"\"\"Return True if ``player`` owns a full row, column or diagonal of ``state``.\"\"\"\n",
"        owned = state == player\n",
"        lines = [*owned, *owned.T, owned.diagonal(), np.fliplr(owned).diagonal()]\n",
"        return any(line.all() for line in lines)\n",
"\n",
"    def check_if_terminal(self, node: Node) -> bool:\n",
"        \"\"\"Return True if the game is over at ``node``: board full or a player has a line.\"\"\"\n",
"        if not (node.state == -1).any():\n",
"            return True\n",
"        return self._has_line(node.state, 0) or self._has_line(node.state, 1)"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "0IdMpN95AeGo"
},
"source": [
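"## Play\n",
"\n",
"Create a game, then call `game.make_move((row, col))` with zero-based coordinates to place your mark; the computer replies automatically."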
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J8bPTXmfAdE7",
"outputId": "8825399e-aee5-4fcb-b255-112c5b8080a4"
},
"source": [
"# Start a fresh game (the constructor prints the empty board).\n",
"game = TicTacToe()"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"New game!\n",
"Current board:\n",
" [[-1. -1. -1.]\n",
" [-1. -1. -1.]\n",
" [-1. -1. -1.]] \n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sp8qUHBkAZgi",
"outputId": "ba890e84-d00a-42fc-9263-d6ce741f0c5c"
},
"source": [
"# Human (player 1) takes the bottom-right corner; the computer replies automatically.\n",
"game.make_move(coord=(2, 2))"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"computer's turn... \n",
"\n",
"Computer played (1,1) \n",
"\n",
"Current board:\n",
" [[-1. -1. -1.]\n",
" [-1. 0. -1.]\n",
" [-1. -1. 1.]] \n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aQ2clnG7Ak0i",
"outputId": "751eb840-6ff0-4405-b57b-08499579bb48"
},
"source": [
"# Human (player 1) takes the bottom-middle square.\n",
"game.make_move(coord=(2, 1))"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"computer's turn... \n",
"\n",
"Computer played (2,0) \n",
"\n",
"Current board:\n",
" [[-1. -1. -1.]\n",
" [-1. 0. -1.]\n",
" [ 0. 1. 1.]] \n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "m0nRlgvaAyFF",
"outputId": "1176d0af-ae22-4f51-afbc-d4a8588d8f43"
},
"source": [
"# Human (player 1) takes the middle-right square.\n",
"game.make_move(coord=(1, 2))"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"computer's turn... \n",
"\n",
"Computer wins :-(\n",
" [[-1. -1. 0.]\n",
" [-1. 0. 1.]\n",
" [ 0. 1. 1.]] \n",
"\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment