Last active
January 1, 2023 21:53
-
-
Save antoniSea/628ec04e1486df6ccccf5a9e8674e23f to your computer and use it in GitHub Desktop.
chess-ai.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"private_outputs": true, | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/antoniSea/628ec04e1486df6ccccf5a9e8674e23f/chess-ai.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fn0sCWoHYsbB" | |
}, | |
"source": [ | |
"!pip install python-chess==0.31.3" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QauvWk2MkddY" | |
}, | |
"source": [ | |
"import chess\n", | |
"import chess.engine\n", | |
"import random\n", | |
"import numpy\n", | |
"\n", | |
"\n", | |
"# this function will create our x (board)\n", | |
"def random_board(max_depth=200):\n", | |
" board = chess.Board()\n", | |
" depth = random.randrange(0, max_depth)\n", | |
"\n", | |
" for _ in range(depth):\n", | |
" all_moves = list(board.legal_moves)\n", | |
" random_move = random.choice(all_moves)\n", | |
" board.push(random_move)\n", | |
" if board.is_game_over():\n", | |
" break\n", | |
"\n", | |
" return board\n", | |
"\n", | |
"\n", | |
"# this function will create our f(x) (score)\n", | |
"def stockfish(board, depth):\n", | |
" with chess.engine.SimpleEngine.popen_uci('/content/stockfish') as sf:\n", | |
" result = sf.analyse(board, chess.engine.Limit(depth=depth))\n", | |
" score = result['score'].white().score()\n", | |
" return score" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ULOEWyyfYqtq" | |
}, | |
"source": [ | |
"board = random_board()\n", | |
"board" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QtZy4cR8ZMhq" | |
}, | |
"source": [ | |
"print(stockfish(board, 10))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Rdo64dA7dhBE" | |
}, | |
"source": [ | |
"squares_index = {\n", | |
" 'a': 0,\n", | |
" 'b': 1,\n", | |
" 'c': 2,\n", | |
" 'd': 3,\n", | |
" 'e': 4,\n", | |
" 'f': 5,\n", | |
" 'g': 6,\n", | |
" 'h': 7\n", | |
"}\n", | |
"\n", | |
"\n", | |
"# example: h3 -> 17\n", | |
"def square_to_index(square):\n", | |
" letter = chess.square_name(square)\n", | |
" return 8 - int(letter[1]), squares_index[letter[0]]\n", | |
"\n", | |
"\n", | |
"def split_dims(board):\n", | |
" # this is the 3d matrix\n", | |
" board3d = numpy.zeros((14, 8, 8), dtype=numpy.int8)\n", | |
"\n", | |
" # here we add the pieces's view on the matrix\n", | |
" for piece in chess.PIECE_TYPES:\n", | |
" for square in board.pieces(piece, chess.WHITE):\n", | |
" idx = numpy.unravel_index(square, (8, 8))\n", | |
" board3d[piece - 1][7 - idx[0]][idx[1]] = 1\n", | |
" for square in board.pieces(piece, chess.BLACK):\n", | |
" idx = numpy.unravel_index(square, (8, 8))\n", | |
" board3d[piece + 5][7 - idx[0]][idx[1]] = 1\n", | |
"\n", | |
" # add attacks and valid moves too\n", | |
" # so the network knows what is being attacked\n", | |
" aux = board.turn\n", | |
" board.turn = chess.WHITE\n", | |
" for move in board.legal_moves:\n", | |
" i, j = square_to_index(move.to_square)\n", | |
" board3d[12][i][j] = 1\n", | |
" board.turn = chess.BLACK\n", | |
" for move in board.legal_moves:\n", | |
" i, j = square_to_index(move.to_square)\n", | |
" board3d[13][i][j] = 1\n", | |
" board.turn = aux\n", | |
"\n", | |
" return board3d" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gHONl_M1hG9i" | |
}, | |
"source": [ | |
"split_dims(board)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6S7QNZqwmBOP" | |
}, | |
"source": [ | |
"import tensorflow.keras.models as models\n", | |
"import tensorflow.keras.layers as layers\n", | |
"import tensorflow.keras.utils as utils\n", | |
"import tensorflow.keras.optimizers as optimizers\n", | |
"\n", | |
"\n", | |
"def build_model(conv_size, conv_depth):\n", | |
" board3d = layers.Input(shape=(14, 8, 8))\n", | |
"\n", | |
" # adding the convolutional layers\n", | |
" x = board3d\n", | |
" for _ in range(conv_depth):\n", | |
" x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu', data_format='channels_first')(x)\n", | |
" x = layers.Flatten()(x)\n", | |
" x = layers.Dense(64, 'relu')(x)\n", | |
" x = layers.Dense(1, 'sigmoid')(x)\n", | |
"\n", | |
" return models.Model(inputs=board3d, outputs=x)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3IjDiS3Bmo5m" | |
}, | |
"source": [ | |
"model = build_model(32, 4)\n", | |
"utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=False)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tAFSOFc8pJf8" | |
}, | |
"source": [ | |
"def build_model_residual(conv_size, conv_depth):\n", | |
" board3d = layers.Input(shape=(14, 8, 8))\n", | |
"\n", | |
" # adding the convolutional layers\n", | |
" x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(board3d)\n", | |
" for _ in range(conv_depth):\n", | |
" previous = x\n", | |
" x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)\n", | |
" x = layers.BatchNormalization()(x)\n", | |
" x = layers.Activation('relu')(x)\n", | |
" x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)\n", | |
" x = layers.BatchNormalization()(x)\n", | |
" x = layers.Add()([x, previous])\n", | |
" x = layers.Activation('relu')(x)\n", | |
" x = layers.Flatten()(x)\n", | |
" x = layers.Dense(1, 'sigmoid')(x)\n", | |
"\n", | |
" return models.Model(inputs=board3d, outputs=x)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CkOXxmoVyHdc", | |
"cellView": "both" | |
}, | |
"source": [ | |
"import tensorflow.keras.callbacks as callbacks\n", | |
"\n", | |
"\n", | |
"def get_dataset():\n", | |
"\tcontainer = numpy.load('dataset.npz')\n", | |
"\tb, v = container['b'], container['v']\n", | |
"\tv = numpy.asarray(v / abs(v).max() / 2 + 0.5, dtype=numpy.float32) # normalization (0 - 1)\n", | |
"\treturn b, v\n", | |
"\n", | |
"\n", | |
"x_train, y_train = get_dataset()\n", | |
"print(x_train.shape)\n", | |
"print(y_train.shape)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RyOYq9mv2ppC" | |
}, | |
"source": [ | |
"model.compile(optimizer=optimizers.Adam(5e-4), loss='mean_squared_error')\n", | |
"model.summary()\n", | |
"model.fit(x_train, y_train,\n", | |
" batch_size=2048,\n", | |
" epochs=1000,\n", | |
" verbose=1,\n", | |
" validation_split=0.1,\n", | |
" callbacks=[callbacks.ReduceLROnPlateau(monitor='loss', patience=10),\n", | |
" callbacks.EarlyStopping(monitor='loss', patience=15, min_delta=1e-4)])\n", | |
"\n", | |
"model.save('model.h5')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "e4CfjcGorHzg" | |
}, | |
"source": [ | |
"# used for the minimax algorithm\n", | |
"def minimax_eval(board):\n", | |
" board3d = split_dims(board)\n", | |
" board3d = numpy.expand_dims(board3d, 0)\n", | |
" return model.predict(board3d)[0][0]\n", | |
"\n", | |
"\n", | |
"def minimax(board, depth, alpha, beta, maximizing_player):\n", | |
" if depth == 0 or board.is_game_over():\n", | |
" return minimax_eval(board)\n", | |
" \n", | |
" if maximizing_player:\n", | |
" max_eval = -numpy.inf\n", | |
" for move in board.legal_moves:\n", | |
" board.push(move)\n", | |
" eval = minimax(board, depth - 1, alpha, beta, False)\n", | |
" board.pop()\n", | |
" max_eval = max(max_eval, eval)\n", | |
" alpha = max(alpha, eval)\n", | |
" if beta <= alpha:\n", | |
" break\n", | |
" return max_eval\n", | |
" else:\n", | |
" min_eval = numpy.inf\n", | |
" for move in board.legal_moves:\n", | |
" board.push(move)\n", | |
" eval = minimax(board, depth - 1, alpha, beta, True)\n", | |
" board.pop()\n", | |
" min_eval = min(min_eval, eval)\n", | |
" beta = min(beta, eval)\n", | |
" if beta <= alpha:\n", | |
" break\n", | |
" return min_eval\n", | |
"\n", | |
"\n", | |
"# this is the actual function that gets the move from the neural network\n", | |
"def get_ai_move(board, depth):\n", | |
" max_move = None\n", | |
" max_eval = -numpy.inf\n", | |
"\n", | |
" for move in board.legal_moves:\n", | |
" board.push(move)\n", | |
" eval = minimax(board, depth - 1, -numpy.inf, numpy.inf, False)\n", | |
" board.pop()\n", | |
" if eval > max_eval:\n", | |
" max_eval = eval\n", | |
" max_move = move\n", | |
" \n", | |
" return max_move" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "C63ND0E_uffp" | |
}, | |
"source": [ | |
"board = chess.Board()\n", | |
"\n", | |
"while True:\n", | |
" move = get_ai_move(board, 1)\n", | |
" board.push(move)\n", | |
" print(f'\\n{board}')\n", | |
" if board.is_game_over():\n", | |
" break\n", | |
"\n", | |
" move = get_ai_move(board, 1)\n", | |
" board.push(move)\n", | |
" print(f'\\n{board}')\n", | |
" if board.is_game_over():\n", | |
" break" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pyCtTjfI2aaW" | |
}, | |
"source": [ | |
"!cp \"/content/drive/My Drive/dataset.zip\" /content/dataset.zip\n", | |
"!unzip dataset.zip\n", | |
"!rm dataset.zip\n", | |
"!chmod +x stockfish" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O0sjAB1A2Y0U" | |
}, | |
"source": [ | |
"import random\n", | |
"random.seed(37)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment