Created
July 12, 2022 20:36
-
-
Save xtpor/d207349bd51e7d3ca5f86fd3faf29a47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta http-equiv="X-UA-Compatible" content="IE=edge"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Connect 4!</title> | |
</head> | |
<body> | |
Please open the browser's developer console: the board, the moves and the AI's suggestions are printed there.
<div> | |
<button onclick="advance(0)">0</button> | |
<button onclick="advance(1)">1</button> | |
<button onclick="advance(2)">2</button> | |
<button onclick="advance(3)">3</button> | |
<button onclick="advance(4)">4</button> | |
<button onclick="advance(5)">5</button> | |
<button onclick="advance(6)">6</button> | |
</div> | |
<div> | |
<button onclick="pp()">print</button> | |
<button onclick="think(1000)">think</button> | |
<button onclick="suggestion()">suggest</button> | |
</div> | |
<script> | |
// https://jonathan-hui.medium.com/monte-carlo-tree-search-mcts-in-alphago-zero-8a403588276a | |
// Identifiers of the two players; these values are also what board cells hold.
const [PLAYER1, PLAYER2] = ["P1", "P2"]
// Outcome labels.
const [WIN, LOSE, DRAW] = ["WIN", "LOSE", "DRAW"]
// Weight of the exploration term used when ranking tree nodes during selection.
const EXPLORATION_FACTOR = 0.9
/**
 * Picks one element of `array` at a random index.
 *
 * @param {Array} array - candidates to choose from
 * @returns {*} the chosen element; `undefined` when the array is empty
 */
function pickRandomArrayElement(array) {
  return array[Math.floor(Math.random() * array.length)]
}
/**
 * Creates a fresh game: empty board, PLAYER1 to move.
 *
 * @returns {{kind: string, player: string, board: object}} the new game state
 */
function makeGame() {
  const game = { kind: "game" }
  game.player = PLAYER1
  game.board = makeBoard()
  return game
}
/**
 * Returns an independent copy of `game` whose board can be mutated without
 * affecting the original.
 *
 * The copy carries the same `kind: "game"` tag that makeGame() sets, so copied
 * and freshly created games have the same shape.
 *
 * @param {{player: string, board: object}} game - game state to duplicate
 * @returns {{kind: string, player: string, board: object}} the copy
 */
function copyGame(game) {
  return {
    kind: "game",
    player: game.player,
    board: copyBoard(game.board),
  }
}
/**
 * Creates an empty board: a flat array of 6 * 7 cells, all `null`.
 *
 * @returns {{kind: string, data: Array<null>}} the new board
 */
function makeBoard() {
  const cellCount = 6 * 7
  const data = []
  while (data.length < cellCount) {
    data.push(null)
  }
  return { kind: "board", data }
}
/**
 * Duplicates a board so the copy's cells can be changed independently.
 *
 * @param {{data: Array}} board - board to duplicate
 * @returns {{kind: string, data: Array}} a new board with a copied cell array
 */
function copyBoard(board) {
  const data = []
  for (const cell of board.data) {
    data.push(cell)
  }
  return { kind: "board", data }
}
/**
 * Reads the cell at column `x`, row `y`.
 *
 * @param {{data: Array}} board - board to read from
 * @param {[number, number]} position - `[x, y]` with x in 0..6 and y in 0..5
 * @returns {*} the cell content, or `null` when the position is off the board
 */
function boardGet(board, [x, y]) {
  const columnOnBoard = 0 <= x && x < 7
  const rowOnBoard = 0 <= y && y < 6
  if (!columnOnBoard || !rowOnBoard) {
    return null
  }
  // Cells are stored row by row, seven per row.
  const index = x + y * 7
  return board.data[index]
}
/**
 * Writes `item` into the cell at column `x`, row `y`.
 *
 * @param {{data: Array}} board - board to modify in place
 * @param {[number, number]} position - `[x, y]` with x in 0..6 and y in 0..5
 * @param {*} item - value to store
 * @returns {*} the stored item
 * @throws {Error} when a coordinate is not a number or lies off the board
 */
function boardPut(board, [x, y], item) {
  for (const coordinate of [x, y]) {
    if (typeof coordinate !== "number") {
      throw new Error("assertion error")
    }
  }
  const onBoard = 0 <= x && x < 7 && 0 <= y && y < 6
  if (!onBoard) {
    throw new Error("assertion error")
  }
  board.data[x + y * 7] = item
  return item
}
/**
 * Renders the board as text: a column-number header, a separator line, then
 * one line per row prefixed with the row number.
 *
 * PLAYER1 cells print as "A", PLAYER2 cells as "B" and empty (null) cells as ".".
 *
 * @param {{data: Array}} board - board to render
 * @returns {string} the multi-line rendering, ending with a newline
 */
function prettyPrintBoard(board) {
  const glyphOf = (cell) => {
    if (cell === PLAYER1) return "A "
    if (cell === PLAYER2) return "B "
    if (cell === null) return ". "
    // Any other value contributes nothing to the line.
    return ""
  }
  let text = " |0 1 2 3 4 5 6 \n" + "-+--------------\n"
  for (let row = 0; row < 6; row += 1) {
    let line = row + "|"
    for (let column = 0; column < 7; column += 1) {
      line += glyphOf(boardGet(board, [column, row]))
    }
    text += line + "\n"
  }
  return text
}
/**
 * Lists the columns that can still take a piece, i.e. whose top cell (row 0)
 * is empty.
 *
 * @param {{board: object}} game - game state to inspect
 * @returns {number[]} playable column indices in ascending order
 */
function availableMoves(game) {
  const open = []
  let column = 0
  while (column < 7) {
    const topCell = boardGet(game.board, [column, 0])
    if (!topCell) {
      open.push(column)
    }
    column += 1
  }
  return open
}
/**
 * Drops the current player's piece into column `move` and passes the turn to
 * the opponent. Mutates `game` in place.
 *
 * The piece lands in the lowest empty cell of the column (row 5 is the bottom).
 *
 * @param {{player: string, board: object}} game - game state to modify
 * @param {number} move - column index, an integer in 0..6
 * @throws {Error} when `move` is not an integer in 0..6, or the column is full
 */
function makeMove(game, move) {
  // A fractional column would pass a pure range check and then write to a
  // non-existent cell index, so integrality is checked explicitly.
  if (!Number.isInteger(move) || !(0 <= move && move < 7)) {
    throw new Error("assertion error, invalid move")
  }
  for (let row = 5; row >= 0; row -= 1) {
    const spot = [move, row]
    if (!boardGet(game.board, spot)) {
      boardPut(game.board, spot, game.player)
      game.player = opponentOf(game.player)
      return
    }
  }
  // Every cell of the column is occupied.
  throw new Error("Assertion error")
}
/**
 * Returns the other player.
 *
 * @param {string} player - PLAYER1 or PLAYER2
 * @returns {string} PLAYER2 for PLAYER1 and vice versa
 * @throws {Error} when `player` is neither of the two players
 */
function opponentOf(player) {
  switch (player) {
    case PLAYER1:
      return PLAYER2
    case PLAYER2:
      return PLAYER1
    default:
      throw new Error("Assertion error")
  }
}
/**
 * Checks whether all the given positions hold the same non-empty cell value.
 *
 * @param {{board: object}} game - game state to inspect
 * @param {Array<[number, number]>} positions - `[x, y]` cells to compare
 * @returns {*} the shared cell value, or `null` when the first cell is empty
 *   or any later cell differs from it
 */
function checkPattern(game, positions) {
  const [first, ...others] = positions
  const owner = boardGet(game.board, first)
  if (owner === null) {
    return null
  }
  for (const position of others) {
    if (boardGet(game.board, position) !== owner) {
      return null
    }
  }
  return owner
}
/**
 * Looks for four equal pieces in a line, starting from every cell and going
 * horizontally, vertically and along both diagonals.
 *
 * Lines that run off the board never match, because off-board cells read as
 * `null`.
 *
 * @param {{board: object}} game - game state to inspect
 * @returns {string|null} the player owning the first line found, else `null`
 */
function winnerOf(game) {
  // Step per cell for: horizontal, vertical, diagonal, anti-diagonal.
  const directions = [[1, 0], [0, 1], [1, 1], [1, -1]]
  for (let column = 0; column < 7; column += 1) {
    for (let row = 0; row < 6; row += 1) {
      for (const [dx, dy] of directions) {
        const line = []
        for (let step = 0; step < 4; step += 1) {
          line.push([column + dx * step, row + dy * step])
        }
        const owner = checkPattern(game, line)
        if (owner) {
          return owner
        }
      }
    }
  }
  return null
}
/**
 * Creates a search-tree node for `game`.
 *
 * `children` gets one key per available move, each initialised to `null`
 * (not expanded yet). When the game already has a winner, `children` stays
 * empty.
 *
 * @param {object|null} parent - parent node, or `null` for a root
 * @param {object} game - game state this node represents
 * @returns {object} the new node with zeroed score and simulation counters
 */
function makeTreeNode(parent, game) {
  const node = {
    kind: "node",
    parent,
    children: {},
    game,
    totalScore: 0,
    totalSimulationCount: 0,
  }
  if (winnerOf(game)) {
    return node
  }
  for (const move of availableMoves(game)) {
    node.children[move] = null
  }
  return node
}
/**
 * Tells whether selection should stop at `node`: it either has no children at
 * all or still has at least one unexpanded (falsy) child slot.
 *
 * @param {{children: object}} node - node to test
 * @returns {boolean} `true` when the node is a leaf in that sense
 */
function isLeafNode(node) {
  const slots = Object.values(node.children)
  let isLeaf = slots.length === 0
  for (let i = 0; i < slots.length && !isLeaf; i += 1) {
    isLeaf = !slots[i]
  }
  return isLeaf
}
/**
 * Tells whether `node` has no child slots at all, i.e. no move can follow it.
 *
 * @param {{children: object}} node - node to test
 * @returns {boolean} `true` when `node.children` is empty
 */
function isTerminalNode(node) {
  const { length: slotCount } = Object.values(node.children)
  return slotCount === 0
}
/**
 * Ranks `node` among the children of `parentNode` during selection:
 * mean score (exploitation) plus an exploration bonus that shrinks as the
 * child accumulates simulations.
 *
 * Scores are accumulated from PLAYER1's perspective (a PLAYER1 win counts +1,
 * a PLAYER2 win -1). The player choosing among the children is the one to
 * move at `parentNode`, so when that player is PLAYER2 the mean is negated;
 * otherwise PLAYER2 would be modelled as picking the moves best for PLAYER1.
 *
 * @param {{totalSimulationCount: number, game?: {player: string}}} parentNode
 *   - node whose children are being compared
 * @param {{totalScore: number, totalSimulationCount: number}} node
 *   - child being ranked
 * @returns {number} the fitness; higher means more attractive to the mover
 * @throws {Error} when `node` has never been simulated
 */
function computeFitness(parentNode, node) {
  if (node.totalSimulationCount === 0) {
    throw new Error("assertion error")
  }
  const meanScore = node.totalScore / node.totalSimulationCount
  // Fall back to PLAYER1's view when the parent carries no game state.
  const moverIsPlayer2 = Boolean(parentNode.game) && parentNode.game.player === PLAYER2
  const exploitation = moverIsPlayer2 ? -meanScore : meanScore
  const exploration = EXPLORATION_FACTOR * (Math.log2(parentNode.totalSimulationCount) / node.totalSimulationCount)
  return exploitation + exploration
}
/**
 * Walks down the tree from `node`, at each level moving to the child with the
 * highest fitness, until it reaches a leaf as defined by isLeafNode().
 *
 * On equal fitness the child encountered first is kept.
 *
 * @param {object} node - node to start from
 * @returns {object} the leaf node reached
 * @throws {Error} when `node` is falsy, or no leaf is reached within 100 levels
 */
function select(node) {
  if (!node) {
    throw new Error("assertion error")
  }
  let current = node
  // Depth guard so a malformed tree cannot loop forever.
  let levelsLeft = 100
  while (levelsLeft > 0) {
    levelsLeft -= 1
    if (isLeafNode(current)) {
      return current
    }
    let best = null
    let bestFitness
    for (const candidate of Object.values(current.children)) {
      const fitness = computeFitness(current, candidate)
      if (best === null || fitness > bestFitness) {
        best = candidate
        bestFitness = fitness
      }
    }
    current = best
  }
  throw new Error("assertion error")
}
/**
 * Finds the first move of `node` whose child has not been created yet.
 *
 * @param {{children: object}} node - node to inspect
 * @returns {number} the move (child key converted to a number)
 * @throws {Error} when every child slot is already filled
 */
function selectUnexpanded(node) {
  const slots = Object.entries(node.children)
  for (let i = 0; i < slots.length; i += 1) {
    const [key, occupant] = slots[i]
    if (occupant) {
      continue
    }
    return Number(key)
  }
  throw new Error("assertion error")
}
/**
 * Creates the child of `node` reached by playing `targetMove` and stores it
 * in `node.children`.
 *
 * The move is played on a copy of the node's game, so `node.game` itself is
 * left untouched.
 *
 * @param {object} node - tree node to expand
 * @param {number} targetMove - column to play
 * @returns {object} the newly created child node
 * @throws {Error} when `node` is not a tree node
 */
function expand(node, targetMove) {
  const isTreeNode = Boolean(node) && node.kind === "node"
  if (!isTreeNode) {
    throw new Error("assertion error")
  }
  const nextGame = copyGame(node.game)
  makeMove(nextGame, targetMove)
  const child = makeTreeNode(node, nextGame)
  node.children[targetMove] = child
  return child
}
/**
 * Plays random moves on `game` (mutating it) until the game ends, and scores
 * the result from PLAYER1's point of view.
 *
 * @param {object} game - game state to play out; pass a copy to keep the original
 * @returns {number} 1 when PLAYER1 wins, -1 when PLAYER2 wins, 0 when no moves
 *   remain without a winner
 * @throws {Error} when the game has not ended after 100 iterations
 */
function simulate(game) {
  /* impose a max iteration count for safety */
  let stepsLeft = 100
  while (stepsLeft > 0) {
    stepsLeft -= 1
    switch (winnerOf(game)) {
      case PLAYER1:
        return 1
      case PLAYER2:
        return -1
    }
    const legalMoves = availableMoves(game)
    if (legalMoves.length === 0) {
      return 0
    }
    makeMove(game, pickRandomArrayElement(legalMoves))
  }
  throw new Error("assertion error")
}
/**
 * Adds `score` and one simulation to `node` and to every ancestor reachable
 * through `parent` links.
 *
 * @param {object|null} node - node where the simulation was run
 * @param {number} score - simulation result to accumulate
 */
function backpropagate(node, score) {
  for (let current = node; current; current = current.parent) {
    current.totalSimulationCount += 1
    current.totalScore += score
  }
}
/**
 * Runs one search iteration on `tree`: select a leaf, expand it by one
 * unexpanded move unless it is terminal, play a random game from the
 * resulting position, and propagate the score back up.
 *
 * @param {object} tree - root node of the search tree
 */
function stepMonteCarloTreeSearch(tree) {
  const leaf = select(tree)
  const target = isTerminalNode(leaf) ? leaf : expand(leaf, selectUnexpanded(leaf))
  // Simulate on a copy so the node keeps its own position intact.
  const playout = copyGame(target.game)
  backpropagate(target, simulate(playout))
}
// Root of the current search tree; starts at the initial position of a new game.
const initialGame = makeGame()
let tree = makeTreeNode(null, initialGame)
// PUBLIC API FOR TESTING | |
/**
 * Runs the given number of search iterations on the current tree and reports
 * the elapsed time on the console under the label "MCTS".
 *
 * @param {number} [iterations=1000] - how many iterations to run
 */
function think(iterations = 1000) {
  console.time("MCTS")
  let completed = 0
  while (completed < iterations) {
    stepMonteCarloTreeSearch(tree)
    completed += 1
  }
  console.timeEnd("MCTS")
}
/**
 * Prints a table with one row per move available at the current root: the
 * move, its mean score and how many simulations went through it.
 *
 * Moves that have not been expanded yet (their child slot is still `null`) or
 * have no simulations are listed with a mean score of "n/a" and a count of 0,
 * so the table can be printed at any time, including before any search.
 */
function suggestion() {
  console.log("AI policy table")
  const table = []
  for (const [move, node] of Object.entries(tree.children)) {
    const simulations = node ? node.totalSimulationCount : 0
    table.push({
      move: move,
      meanScore: simulations > 0 ? (node.totalScore / simulations).toFixed(4) : "n/a",
      totalSimulationCount: simulations,
    })
  }
  console.table(table)
}
/**
 * Plays `move` in the current position and makes the resulting node the new
 * root of the search tree, reusing the existing subtree when there is one.
 *
 * Only moves listed in the root's `children` are accepted; those are the
 * moves that were available when the node was created, and the list is empty
 * once the game has a winner. Any other move is reported with a console
 * warning and leaves the tree unchanged.
 *
 * @param {number} move - column to play
 */
function advance(move) {
  if (!Object.prototype.hasOwnProperty.call(tree.children, move)) {
    console.warn("the move", move, "is not available in the current position")
    return
  }
  const next = tree.children[move] || expand(tree, move)
  // Detach from the old root so backpropagation stops at the new one.
  next.parent = null
  tree = next
  console.log("the move", move, "has been made")
}
/**
 * Logs the player to move and a text rendering of the board for the position
 * at the current root.
 */
function pp() {
  const { player, board } = tree.game
  console.log("Player:", player)
  console.log(prettyPrintBoard(board))
}
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment