Last active
March 23, 2023 08:14
-
-
Save panchishin/5d50b0c3660f923883c626002ec2a4d1 to your computer and use it in GitHub Desktop.
Tic Tac Toe for CodinGame using MCTS in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import math | |
import random | |
from copy import deepcopy | |
from datetime import datetime | |
# Bitmasks of the eight three-in-a-line patterns on the 9-bit board.
# Stored as a tuple instead of a list for speed.
winning_states = (
    # rows: three adjacent bits, shifted one row (3 bits) at a time
    *(0b111 << shift for shift in (6, 3, 0)),
    # columns: every third bit, shifted one cell at a time
    *(0b001_001_001 << shift for shift in (2, 1, 0)),
    # the two diagonals
    0b100_010_001,
    0b001_010_100,
)
def isWinFunction(x):
    """Return True when bitboard ``x`` contains every bit of some winning line.

    ``x`` is one player's occupancy mask; it is compared against each entry
    of ``winning_states``.
    """
    return any((x & line) == line for line in winning_states)
# Precomputed win table: isWinLookup[board] holds isWinFunction(board), so
# callers index a list instead of re-scanning the winning lines.
isWinLookup = list(map(isWinFunction, range(1 << 10)))
# Game class which holds the game state including the board and LAST player to move
class Game:
    """Tic-tac-toe position stored as one 9-bit occupancy mask per player.

    Attributes:
        gameBoard: two-element list; ``gameBoard[p]`` has a bit set for every
            cell taken by player ``p``.
        playerId: id (0 or 1) of the player who made the most recent move.
            A fresh game starts at 1, so the first ``move`` is played by 0.
    """

    # Mask with all nine cells set, and the nine single-cell masks
    # (lowest bit first).
    _ALL_CELLS = 0b111_111_111
    _CELL_MASKS = tuple(1 << bit for bit in range(9))

    def __init__(self, other=None):
        """Create an empty game, or an independent copy of ``other``."""
        if other is None:
            self.gameBoard = [0b000_000_000, 0b000_000_000]
            self.playerId = 1
        else:
            # The board is a flat list of ints, so a shallow list copy is
            # already a full copy; deepcopy is unnecessary overhead here.
            self.gameBoard = list(other.gameBoard)
            self.playerId = other.playerId

    def _getAvailableMoves(self):
        """Return a bitmask with a bit set for every unoccupied cell."""
        return (self.gameBoard[0] | self.gameBoard[1]) ^ self._ALL_CELLS

    def getScore(self):
        """Return 1 if the last player to move has a winning line, else 0.5."""
        return 1 if isWinLookup[self.gameBoard[self.playerId]] else 0.5

    def isGameOver(self):
        """Return True when the board is full or the last mover has won."""
        return self._getAvailableMoves() == 0 or self.getScore() == 1

    def move(self, move):
        """Switch to the other player and mark ``move`` (a cell bitmask) for them.

        The move is not validated against occupied cells. Returns ``self`` so
        calls can be chained.
        """
        self.playerId = 1 - self.playerId
        self.gameBoard[self.playerId] |= move
        return self

    def randomMove(self):
        """Play a uniformly random free cell and return ``self``.

        Raises:
            IndexError: from ``random.choice`` when no cell is free.
        """
        available = self._getAvailableMoves()
        choices = [cell for cell in self._CELL_MASKS if cell & available]
        self.move(random.choice(choices))
        return self

    def print(self):
        """Print both boards as 9-character binary strings and the last mover's id."""
        print("'{0:9b}'".format(self.gameBoard[0]) + " " + "'{0:9b}'".format(self.gameBoard[1]) + " " + str(self.playerId))
def gameTest():
    """Self-check for ``Game``; raises AssertionError on the first failure.

    All move sequences are legal (no cell is played twice), with the first
    mover and second mover alternating as ``Game.move`` does.
    """
    # test Game._getAvailableMoves and Game.move
    assert Game()._getAvailableMoves() == 0b111_111_111
    assert Game().move(0b1)._getAvailableMoves() == 0b111_111_110
    assert Game().move(0b1).move(0b100)._getAvailableMoves() == 0b111_111_010

    # test Game.getScore
    assert Game().getScore() == 0.5
    assert Game().move(0b1).getScore() == 0.5
    # First mover holds 0b1 and 0b1000, second mover holds 0b10 and 0b100:
    # nobody has a line yet.
    unfinished = Game().move(0b1).move(0b10).move(0b1000).move(0b100)
    assert unfinished.getScore() == 0.5
    # The first mover completes 0b001_001_001.
    won = Game(unfinished).move(0b1000000)
    assert won.getScore() == 1

    # test Game.isGameOver
    assert Game().isGameOver() == False
    assert unfinished.isGameOver() == False
    assert won.isGameOver() == True
    # A full board without any line is over, and scores as a draw.
    draw = Game()
    for cell in (0, 1, 2, 4, 3, 5, 7, 6, 8):
        draw.move(1 << cell)
    assert draw.isGameOver() == True
    assert draw.getScore() == 0.5

    # test Game.randomMove
    assert Game().randomMove()._getAvailableMoves() != 0
    temp = Game()
    temp.gameBoard[0] = 0b111_011_111
    assert temp.randomMove()._getAvailableMoves() == 0


gameTest()
# Weight of the exploration term in the UCB score computed by Node.select.
UCB_C = 5.0


class Node:
    """One node of the Monte Carlo tree search tree.

    Attributes:
        parent: the parent ``Node``, or None for the root.
        action: cell bitmask of the move that leads from the parent to this node.
        playerId: id of the player who made ``action``; alternates per level.
        visits: number of results backpropagated through this node.
        score: sum of those results, each expressed for ``playerId``.
        children: None until ``expand`` has run, then a (possibly empty) list.
    """

    def __init__(self, parent=None, action=0):
        self.parent = parent
        self.action = action
        self.playerId = 1 if parent is None else 1 - parent.playerId
        self.visits = 0
        self.score = 0
        self.children = None

    def expand(self, game):
        """Create the child nodes once and return them.

        ``game`` must be the position this node represents. A finished game
        (won or full board) gets an empty child list, so the search never
        plays on past a terminal position.
        """
        if self.children is None:
            if game.isGameOver():
                self.children = []
            else:
                legalMoves = (game.gameBoard[0] | game.gameBoard[1]) ^ 0b111_111_111
                self.children = [Node(self, 1 << bit) for bit in range(9) if legalMoves & (1 << bit)]
        return self.children

    def select(self, game):
        """Walk down the tree from this node, applying each chosen move to ``game``.

        At every level the first unvisited child is returned immediately;
        otherwise the child maximising
        ``score/visits + UCB_C * sqrt(log(parent.visits + 1) / visits)``
        is followed. Returns the node where the walk stops, which is this
        node itself when it has no children.
        """
        node = self
        while node.children:
            sqrtLogParentVisits = math.log(node.visits + 1) ** 0.5
            bestChild = None
            bestScore = -1e99
            for child in node.children:
                if child.visits == 0:
                    game.move(child.action)
                    return child
                invSqrtVisits = math.sqrt(1 / child.visits)
                newScore = (child.score * invSqrtVisits + UCB_C * sqrtLogParentVisits) * invSqrtVisits
                if newScore > bestScore:
                    bestChild = child
                    bestScore = newScore
            game.move(bestChild.action)
            node = bestChild
        return node

    def rollout(self, game):
        """Play random moves on ``game`` until it ends and return the result.

        The result is expressed for this node's ``playerId``: the game's score
        if that player made the final move, otherwise ``1 - score``.
        """
        while not game.isGameOver():
            game.randomMove()
        score = game.getScore()
        return score if self.playerId == game.playerId else 1 - score

    def backpropagate(self, result):
        """Add ``result`` to this node and every ancestor, flipping it per level."""
        curr = self
        while curr is not None:
            curr.score += result
            # The parent belongs to the other player, so invert the result.
            result = 1 - result
            curr.visits += 1
            curr = curr.parent

    def MCTS(self, startGame, time_limit_ms=95):
        """Search from ``startGame`` for ``time_limit_ms`` and return the best child.

        Args:
            startGame: position to search from; it is copied for every
                iteration and never modified.
            time_limit_ms: wall-clock search budget in milliseconds.

        Returns:
            The child of this node with the highest mean score, or None when
            there is no child (no iteration ran, or the start position is
            already over). Per-child statistics go to stderr.
        """
        self.playerId = startGame.playerId
        start_time = datetime.now()
        iterations = 0
        while (datetime.now() - start_time).total_seconds() * 1000 < time_limit_ms:
            game = Game(startGame)
            curr = self.select(game)
            curr.expand(game)
            curr = curr.select(game)
            outcome = curr.rollout(game)
            curr.backpropagate(outcome)
            iterations += 1
        bestScore = -1e99
        bestChild = None
        # children is still None if the time budget allowed no iteration.
        for child in self.children or ():
            # A child that was never visited has no statistics; score it 0.
            meanScore = child.score / child.visits if child.visits else 0.0
            if meanScore > bestScore:
                bestChild = child
                bestScore = meanScore
            print("action", "'{0:9b}'".format(child.action), "visits", child.visits, f" with {int(meanScore*100)} % win", file=sys.stderr, flush=True)
        print(f"Completed {iterations} iterations", file=sys.stderr, flush=True)
        return bestChild
def nodeTest():
    """Run two searches and assert that each one picks the centre cell.

    The first search starts from an empty board, the second after the cell
    with mask 1 has been played. Results are reported on stderr.
    """
    centre = 0b10000

    def searchAndReport(position):
        # Run one search from `position`, report it, and return the chosen action.
        best = Node().MCTS(position)
        print("Best action", "'{0:9b}'".format(best.action), f" with {int((best.score/best.visits)*100)} % chance for win", file=sys.stderr, flush=True)
        return best.action

    chosen = searchAndReport(Game())
    assert chosen == centre
    print("", file=sys.stderr, flush=True)
    chosen = searchAndReport(Game().move(1))
    assert chosen == centre


nodeTest()
def playCodinGameMulti():
    """Play tic-tac-toe over stdin/stdout, one turn per loop iteration.

    Each turn reads the opponent's move as "row col", then a count followed
    by that many lines which are discarded, and prints our reply as
    "row col". Cell (row, col) corresponds to bit ``row * 3 + col``.
    The loop only ends when ``input()`` raises or no move is available.
    """
    game = Game()
    while True:
        r, c = map(int, input().split())
        # The listed lines are not needed: the free cells are tracked in `game`.
        for _ in range(int(input())):
            input()
        # r == -1 means there is no opponent move to apply
        # (presumably sent when we move first -- confirm against the protocol).
        if r != -1:
            game.move(1 << (r * 3 + c))
        child = Node().MCTS(game)
        if child is None:
            # No move is available, so there is nothing left to play.
            return
        game.move(child.action)
        # child.action is a single-bit mask; bit_length gives its index
        # exactly, with no floating-point log rounding to truncate.
        action = child.action.bit_length() - 1
        print(f"{action // 3} {action % 3}")
# Uncomment the next line to play the game on CodinGame.
# playCodinGameMulti()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment