{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [
{
"file_id": "https://gist.github.com/kirisakow/7af90f26a8bc3f674058ddda71c6f518",
"timestamp": "1710022634365"
}
],
"collapsed_sections": [
"LgjYFf9l4rTB",
"vVSK4Ow4h3w2",
"GJf0Jd5esEA-",
"vnPYzC3StW48",
"Rp2mQrOwuF6J",
"Fz-LPul00G_u",
"lqGcHme1-C78",
"oy5KNILoYGVY",
"JoAeSYDjocBZ",
"pA6FtZzS7EIO",
"YSbZMCte7q32",
"MFMwmN7K8IRq",
"gYozYnOQ8Wru",
"382xJUI88xic"
],
"authorship_tag": "ABX9TyMIUDDJkOfQJRkL+L4n7Vom",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kirisakow/7af90f26a8bc3f674058ddda71c6f518/build-a-homemade-gpt-from-scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Building a homemade GPT from scratch\n",
"\n",
"Engineering a modern NLP AI model, based on transformers and attention.\n",
"<br><br>\n",
"\n",
"This is a refactored version of [NeetCode ML tutorial][neetcode] and [original notebook][original_notebook], enhanced with coding best practices (constants instead of hardcoded “magic” literals; reusable code wrapped in separate functions; extensive use of generators; etc).\n",
"\n",
"Other resources:\n",
"* [PyTorch][pytorch_docs] library documentation;\n",
"* Andrej Karpathy's [Machine Learning Practice Problems][karpathy_playlist] YouTube playlist;\n",
"* [Modern Approaches in Natural Language Processing][modern_nlp], a 2020 online ebook;\n",
"* [Python][python_docs] official documentation;\n",
"* [Effective Python][effective_python], 2nd ed., a 2019 book by Brett Slatkin;\n",
"* A comprehensive curated list of [software engineering best practices and concepts][best_practices] such as clean code, simple design, software craftsmanship, YAGNI, KISS, SOLID, TDD, design patterns, and others.\n",
"\n",
"Author: [Kiril Isakov][kisakov_linkedin] ([kirisakow][kirisakow_github])\n",
"\n",
"[kisakov_linkedin]: https://www.linkedin.com/in/kisakov/\n",
"[kirisakow_github]: https://github.com/kirisakow\n",
"[neetcode]: https://neetcode.io/practice\n",
"[original_notebook]: https://colab.research.google.com/drive/1L92UwfFlVlog-p8PhKe-ioROAaBa4aFc\n",
"[karpathy_playlist]: https://www.youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ\n",
"[pytorch_docs]: https://pytorch.org/docs/stable/\n",
"[modern_nlp]: https://slds-lmu.github.io/seminar_nlp_ss20\n",
"[python_docs]: https://docs.python.org/\n",
"[effective_python]: https://effectivepython.com\n",
"[best_practices]: https://gitlab.com/kirisakow/clean-code-software-craftsmanship-best-practices"
],
"metadata": {
"id": "fvuSVgapabdw"
}
},
{
"cell_type": "markdown",
"source": [
"#### Install and initialize libraries and constants"
],
"metadata": {
"id": "LgjYFf9l4rTB"
}
},
{
"cell_type": "code",
"source": [
"!pip install --quiet torchtyping"
],
"metadata": {
"id": "NPNprFOc0rZs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from numpy.typing import NDArray\n",
"from torchtyping import TensorType\n",
"from typing import List, Tuple\n",
"import itertools\n",
"import numpy as np\n",
"import re\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"COLS_DIM_INDEX, ROWS_DIM_INDEX = 0, 1\n",
"DECIMAL_PRECISION = 4\n",
"DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"DROPOUT_PROBABILITY = 0.2\n",
"EMB_LAYER_LEN = 16\n",
"FINAL_LAYER_LEN = 10\n",
"FIRST_LAYER_LEN = 512\n",
"IMG_SHAPE = [28, 28]\n",
"KNOWN_OUTPUT_DIMENSION_SIZE = 2\n",
"LEARNING_RATE = 0.01\n",
"LINEAR_NN_SCALE_FACTOR = 4\n",
"MASK_FILLER = 0\n",
"OUTPUT_LAYER_LEN = 1\n",
"T_DIM_INDEX = 1\n",
"TOKEN_DELIMITER = r'\\s'\n",
"WORDS_INDEX = {0: '\\n', 1: ' ', 2: '!', 3: '\"', 4: '$', 5: '%', 6: '&', 7: \"'\", 8: '(', 9: ')', 10: '*',\n",
" 11: '+', 12: ',', 13: '-', 14: '.', 15: '/', 16: '0', 17: '1', 18: '2', 19: '3', 20: '4',\n",
" 21: '5', 22: '6', 23: '7', 24: '8', 25: '9', 26: ':', 27: ';', 28: '?', 29: 'A', 30: 'B',\n",
" 31: 'C', 32: 'D', 33: 'E', 34: 'F', 35: 'G', 36: 'H', 37: 'I', 38: 'J', 39: 'K', 40: 'L',\n",
" 41: 'M', 42: 'N', 43: 'O', 44: 'P', 45: 'Q', 46: 'R', 47: 'S', 48: 'T', 49: 'U', 50: 'V',\n",
" 51: 'W', 52: 'X', 53: 'Y', 54: 'Z', 55: '[', 56: ']', 57: '_', 58: 'a', 59: 'b', 60: 'c',\n",
" 61: 'd', 62: 'e', 63: 'f', 64: 'g', 65: 'h', 66: 'i', 67: 'j', 68: 'k', 69: 'l', 70: 'm',\n",
" 71: 'n', 72: 'o', 73: 'p', 74: 'q', 75: 'r', 76: 's', 77: 't', 78: 'u', 79: 'v', 80: 'w',\n",
" 81: 'x', 82: 'y', 83: 'z', 84: '{', 85: '|', 86: '}', 87: 'à', 88: 'á', 89: 'è', 90: 'é',\n",
" 91: 'ë', 92: 'ñ', 93: 'ó', 94: 'ú', 95: '\\u2005', 96: '–', 97: '—', 98: '‘', 99: '’', 100: '“',\n",
" 101: '”', 102: '…', 103: '\\u205f'}\n",
"WORDS_INDEX = tuple(WORDS_INDEX.values())"
],
"metadata": {
"id": "sB6upU1mr7g_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 0. ML, NLP, and PyTorch fundamentals"
],
"metadata": {
"id": "vVSK4Ow4h3w2"
}
},
{
"cell_type": "markdown",
"source": [
"### 0.1. Gradient Descent"
],
"metadata": {
"id": "wdJ4NLBGiuIy"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "GJf0Jd5esEA-"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A8Sk92kFYhQI"
},
"outputs": [],
"source": [
"# NeetCode solution: https://neetcode.io/problems/gradient-descent\n",
"class Solution:\n",
" def get_minimizer(self, iterations: int, learning_rate: float, init: int) -> float:\n",
" minimizer = init\n",
" for _ in range(iterations):\n",
" derivative = 2 * minimizer\n",
" minimizer = minimizer - learning_rate * derivative\n",
" return round(minimizer, 5)"
]
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"* recursivity;"
],
"metadata": {
"id": "rCbwYQ6Cs7UE"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def get_minimizer(self, iterations: int, learning_rate: float, init: int) -> float:\n",
" \"\"\"Recursively perform a gradient descent toward the minimum of the x² function, with\n",
" * `iterations`: number of steps;\n",
" * `learning_rate`: step width;\n",
" * `init`: current minimum\n",
" \"\"\"\n",
" if iterations == 0:\n",
" return round(init, DECIMAL_PRECISION)\n",
" derivative = 2 * init\n",
" current_guess = init - learning_rate * derivative\n",
" return self.get_minimizer(init=current_guess,\n",
" iterations=iterations - 1,\n",
" learning_rate=learning_rate)"
],
"metadata": {
"id": "UU87zm7ej_Uf"
},
"execution_count": null,
"outputs": []
},
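{
"cell_type": "markdown",
"source": [
"A quick sanity check, added here as an illustration (not part of the original tutorial): the minimum of x² sits at 0, and each recursive step multiplies the current guess by (1 - 2·`LEARNING_RATE`), so the guess should decay toward 0."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"solver = Solution()\n",
"# Each step multiplies the guess by (1 - 2 * LEARNING_RATE) = 0.98,\n",
"# so 10 steps from init=5 should print roughly 5 * 0.98**10 ≈ 4.0854 ...\n",
"print(solver.get_minimizer(iterations=10, learning_rate=LEARNING_RATE, init=5))\n",
"# ... and 500 steps should land very close to the true minimum at 0.\n",
"print(solver.get_minimizer(iterations=500, learning_rate=LEARNING_RATE, init=5))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},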
{
"cell_type": "markdown",
"source": [
"### 0.2. Linear regression: the `forward()` function"
],
"metadata": {
"id": "FTnOHMUlmxC5"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "vnPYzC3StW48"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/linear-regression-forward\n",
"class Solution:\n",
"\n",
" def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" prediction = np.matmul(X, weights)\n",
" return np.round(prediction, 5)\n",
"\n",
" def get_error(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64]) -> float:\n",
" error = np.mean(np.square(model_prediction - ground_truth))\n",
" return round(error, 5)"
],
"metadata": {
"id": "eGUjiRAzoz7C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;"
],
"metadata": {
"id": "qL4xW9kAt15M"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" prediction = np.matmul(X, weights)\n",
" return np.round(prediction, DECIMAL_PRECISION)\n",
"\n",
" def get_error(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64]) -> float:\n",
" error = np.mean(np.square(model_prediction - ground_truth))\n",
" return round(error, DECIMAL_PRECISION)"
],
"metadata": {
"id": "dKrLDvj6pmZ5"
},
"execution_count": null,
"outputs": []
},
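{
"cell_type": "markdown",
"source": [
"A minimal usage sketch on made-up data: a single sample with features `[1, 2, 3]` and weights `[1, 1, 1]` predicts 6, and the mean squared error against a ground truth of 5 is 1."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"solution = Solution()\n",
"X = np.array([[1.0, 2.0, 3.0]]) # one sample, three features (made-up)\n",
"weights = np.array([1.0, 1.0, 1.0])\n",
"prediction = solution.get_model_prediction(X, weights)\n",
"print(prediction) # [6.]\n",
"print(solution.get_error(prediction, np.array([5.0]))) # (6 - 5)² = 1.0"
],
"metadata": {},
"execution_count": null,
"outputs": []
},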
{
"cell_type": "markdown",
"source": [
"### 0.3. Linear regression: training"
],
"metadata": {
"id": "YO5QRRrVqnzg"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "Rp2mQrOwuF6J"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/linear-regression-training\n",
"class Solution:\n",
" def get_derivative(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64], N: int, X: NDArray[np.float64], desired_weight: int) -> float:\n",
" return -2 * np.dot(ground_truth - model_prediction, X[:, desired_weight]) / N\n",
"\n",
" def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" return np.squeeze(np.matmul(X, weights))\n",
"\n",
" learning_rate = 0.01\n",
"\n",
" def train_model(self, X: NDArray[np.float64], Y: NDArray[np.float64], num_iterations: int, initial_weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" for _ in range(num_iterations):\n",
" model_prediction = self.get_model_prediction(X, initial_weights)\n",
"\n",
" d1 = self.get_derivative(model_prediction, Y, len(X), X, 0)\n",
" d2 = self.get_derivative(model_prediction, Y, len(X), X, 1)\n",
" d3 = self.get_derivative(model_prediction, Y, len(X), X, 2)\n",
"\n",
" initial_weights[0] = initial_weights[0] - d1 * self.learning_rate\n",
" initial_weights[1] = initial_weights[1] - d2 * self.learning_rate\n",
" initial_weights[2] = initial_weights[2] - d3 * self.learning_rate\n",
"\n",
" return np.round(initial_weights, 5)"
],
"metadata": {
"id": "HSnokAW_rFjR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"* wrap repeated instructions into separate functions for better readability and reusability;"
],
"metadata": {
"id": "o1IjpcU5u2j1"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def get_derivative(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64], N: int, X: NDArray[np.float64], desired_weight: int) -> float:\n",
" return -2 * np.dot(ground_truth - model_prediction, X[:, desired_weight]) / N\n",
"\n",
" def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" return np.squeeze(np.matmul(X, weights))\n",
"\n",
" def update_weights(self, actual_weights: NDArray[np.float64], model_prediction: NDArray[np.float64],\n",
" Y: NDArray[np.float64], X: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" weights_indices = range(len(X[0]))\n",
" for i in weights_indices:\n",
" derivative = self.get_derivative(model_prediction, Y, len(X), X, i)\n",
" actual_weights[i] -= derivative * LEARNING_RATE\n",
" return actual_weights\n",
"\n",
" def train_model(self, X: NDArray[np.float64], Y: NDArray[np.float64], num_iterations: int, initial_weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
" actual_weights = initial_weights\n",
" for _ in range(num_iterations):\n",
" model_prediction = self.get_model_prediction(X, actual_weights)\n",
" actual_weights = self.update_weights(actual_weights, model_prediction, Y, X)\n",
" return np.round(actual_weights, DECIMAL_PRECISION)\n"
],
"metadata": {
"id": "56aHIdUyriJC"
},
"execution_count": null,
"outputs": []
},
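{
"cell_type": "markdown",
"source": [
"A smoke test on a hypothetical micro-dataset whose targets are generated exactly by the weights `[2, 0, 1]`: after enough iterations, the trained weights should approach those values."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"solution = Solution()\n",
"# Made-up full-rank design matrix; targets follow Y = 2·x0 + 0·x1 + 1·x2\n",
"X = np.array([[1.0, 0.0, 0.0],\n",
"              [0.0, 1.0, 0.0],\n",
"              [0.0, 0.0, 1.0],\n",
"              [1.0, 1.0, 1.0]])\n",
"Y = np.array([2.0, 0.0, 1.0, 3.0])\n",
"trained = solution.train_model(X, Y, num_iterations=2000, initial_weights=np.zeros(3))\n",
"print(trained) # expected to be close to [2. 0. 1.]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},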
{
"cell_type": "markdown",
"source": [
"### 0.4. PyTorch basics"
],
"metadata": {
"id": "cPN38c420BNN"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "Fz-LPul00G_u"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/basics-of-pytorch\n",
"class Solution:\n",
" def reshape(self, to_reshape: TensorType[float]) -> TensorType[float]:\n",
" M, N = to_reshape.shape\n",
" reshaped = torch.reshape(to_reshape, (M * N // 2, 2))\n",
" return torch.round(reshaped, decimals=4)\n",
"\n",
" def average(self, to_avg: TensorType[float]) -> TensorType[float]:\n",
" averaged = torch.mean(to_avg, dim = 0)\n",
" return torch.round(averaged, decimals=4)\n",
"\n",
" def concatenate(self, cat_one: TensorType[float], cat_two: TensorType[float]) -> TensorType[float]:\n",
" concatenated = torch.cat((cat_one, cat_two), dim = 1)\n",
" return torch.round(concatenated, decimals=4)\n",
"\n",
" def get_loss(self, prediction: TensorType[float], target: TensorType[float]) -> TensorType[float]:\n",
" loss = torch.nn.functional.mse_loss(prediction, target)\n",
" return torch.round(loss, decimals=4)\n"
],
"metadata": {
"id": "N6O0I5g80KAJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"* better use of torch API (see `reshape()` function): `-1` is a built-in shortcut for `M⋅N // the_other_output_dim`;"
],
"metadata": {
"id": "FM7sOJG20KqG"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def reshape(self, to_reshape: TensorType[float]) -> TensorType[float]:\n",
" \"\"\"Reshape an M×N tensor into a (M⋅N // 2)×2 tensor\"\"\"\n",
" reshaped = to_reshape.view(-1, KNOWN_OUTPUT_DIMENSION_SIZE)\n",
" return torch.round(reshaped, decimals=DECIMAL_PRECISION)\n",
"\n",
"\n",
" def average(self, to_avg: TensorType[float]) -> TensorType[float]:\n",
" \"\"\"Find the average of every column in a tensor.\"\"\"\n",
" averaged = torch.mean(to_avg, dim=COLS_DIM_INDEX)\n",
" return torch.round(averaged, decimals=DECIMAL_PRECISION)\n",
"\n",
"\n",
" def concatenate(self, cat_one: TensorType[float], cat_two: TensorType[float]) -> TensorType[float]:\n",
" \"\"\"Combine an M×N tensor and a M×M tensor into a M×(M+N) tensor\"\"\"\n",
" concatenated = torch.cat((cat_one, cat_two), dim=ROWS_DIM_INDEX)\n",
" return torch.round(concatenated, decimals=DECIMAL_PRECISION)\n",
"\n",
"\n",
" def get_loss(self, prediction: TensorType[float], target: TensorType[float]) -> TensorType[float]:\n",
" \"\"\"Calculate the mean squared error loss between a prediction and target tensor\"\"\"\n",
" loss = torch.nn.functional.mse_loss(prediction, target)\n",
" return torch.round(loss, decimals=DECIMAL_PRECISION)"
],
"metadata": {
"id": "aYFbq5oS0NxC"
},
"execution_count": null,
"outputs": []
},
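{
"cell_type": "markdown",
"source": [
"A few illustrative calls on made-up values: `reshape()` turns a 2×4 tensor into a 4×2 one, `average()` returns per-column means, `concatenate()` appends columns, and `get_loss()` computes the MSE."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"solution = Solution()\n",
"t = torch.arange(8, dtype=torch.float32).view(2, 4)\n",
"print(solution.reshape(t)) # 4×2 tensor\n",
"print(solution.average(t)) # tensor([2., 3., 4., 5.])\n",
"print(solution.concatenate(t, torch.ones(2, 2))) # 2×6 tensor\n",
"print(solution.get_loss(torch.ones(3), torch.zeros(3))) # tensor(1.)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},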
{
"cell_type": "markdown",
"source": [
"### 0.5. Handwritten digits classifier (based on MNIST dataset)"
],
"metadata": {
"id": "e_tW-g-_46Mw"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "G5SqnDMh47sN"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/handwritten-digit-classifier\n",
"class Solution(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.first_linear = nn.Linear(784, 512)\n",
" self.relu = nn.ReLU()\n",
" self.dropout = nn.Dropout(p=0.2)\n",
" self.projection = nn.Linear(512, 10)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, images: TensorType[float]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" out = self.sigmoid(self.projection(self.dropout(self.relu(self.first_linear(images)))))\n",
" return torch.round(out, decimals=4)"
],
"metadata": {
"id": "Cl2CAsrb5PzW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n",
"* wrap repeated instructions into separate functions for better readability and reusability;"
],
"metadata": {
"id": "m_91hkSu5lKp"
}
},
{
"cell_type": "code",
"source": [
"class Solution(nn.Module):\n",
"\n",
" def get_img_area(self, img_shape: list) -> int:\n",
" img_shape_as_tensor = torch.tensor(img_shape)\n",
" return torch.prod(img_shape_as_tensor)\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.first_linear = nn.Linear(self.get_img_area(IMG_SHAPE), FIRST_LAYER_LEN)\n",
" self.relu = nn.ReLU()\n",
" self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n",
" self.projection = nn.Linear(FIRST_LAYER_LEN, FINAL_LAYER_LEN)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, images: TensorType[float]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" ret = self.first_linear(images)\n",
" ret = self.relu(ret)\n",
" ret = self.dropout(ret)\n",
" ret = self.projection(ret)\n",
" ret = self.sigmoid(ret)\n",
" return torch.round(ret, decimals=DECIMAL_PRECISION)"
],
"metadata": {
"id": "XP14vrnc5pbs"
},
"execution_count": null,
"outputs": []
},
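{
"cell_type": "markdown",
"source": [
"A shape check on random data, assuming (as on the NeetCode platform) that the model receives pre-flattened 28⋅28 pixel vectors: a batch of 3 images should yield 3 probability vectors of length 10, one entry per digit."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"classifier = Solution()\n",
"fake_images = torch.rand(3, IMG_SHAPE[0] * IMG_SHAPE[1]) # made-up batch of flattened images\n",
"print(classifier(fake_images).shape) # torch.Size([3, 10])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},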
{
"cell_type": "markdown",
"source": [
"### 0.6. An introduction to natural language processing (NLP)"
],
"metadata": {
"id": "FZUF8EtQ9yis"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "lqGcHme1-C78"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode Solution: https://neetcode.io/problems/nlp-intro\n",
"class Solution:\n",
" def get_dataset(self, positive: List[str], negative: List[str]) -> TensorType[float]:\n",
" # First let's get the total set of words\n",
" words = set()\n",
" combined = positive + negative\n",
" for sentence in combined:\n",
" for word in sentence.split():\n",
" words.add(word)\n",
"\n",
" # Now let's build a mapping\n",
" sorted_list = sorted(list(words))\n",
" word_to_int = {}\n",
" for i, c in enumerate(sorted_list):\n",
" word_to_int[c] = i + 1\n",
"\n",
" # Write encode() which is used to build the dataset\n",
" def encode(sentence):\n",
" integers = []\n",
" for word in sentence.split():\n",
" integers.append(word_to_int[word])\n",
" return integers\n",
"\n",
" var_len_tensors = []\n",
" for sentence in combined:\n",
" var_len_tensors.append(torch.tensor(encode(sentence)))\n",
"\n",
" return nn.utils.rnn.pad_sequence(var_len_tensors, batch_first = True)"
],
"metadata": {
"id": "3UhjU433-GUd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* add a missing import (`typing.List`)\n",
"* put an explicit value for token delimiter, and use a constant for that purpose instead of a hardcoded value (aka “magic string”);\n",
"* use a regex pattern for token delimiter;\n",
"* wrap repeated instructions into separate functions for better readability and reusability;\n",
"* use [generators][term-generator] (instead of filling up a list, intended to be iterated on later);\n",
"* use a function from [`itertools`][itertools], a handy built-in module;\n",
"* use [list comprehension][term-list-comprehension] and [dictionary comprehension][term-dictionary-comprehension] expressions to fill up lists and dictionaries;\n",
"* use [`enumerate(..., start=1)`][enumerate] for `i` to start from 1 from the get-go, instead of repeatedly incrementing `i + 1`;\n",
"* to access a value by key in a dictionary use failsafe method [`dict.get(key[, default_value])`][dict.get] instead of `dict[key]`;\n",
"\n",
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n",
"[itertools]: https://docs.python.org/3/library/itertools.html\n",
"[term-list-comprehension]: https://docs.python.org/3/glossary.html#term-list-comprehension\n",
"[term-dictionary-comprehension]: https://docs.python.org/3/glossary.html#term-dictionary-comprehension\n",
"[enumerate]: https://docs.python.org/3/library/functions.html#enumerate\n",
"[dict.get]: https://docs.python.org/3/library/stdtypes.html#dict.get"
],
"metadata": {
"id": "19Ug2qAE-M4I"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def tokenize_sentence(self, sentence: str) -> str:\n",
" yield from [word for word in re.split(TOKEN_DELIMITER, sentence) if word != '']\n",
"\n",
" def get_all_words(self, *args) -> str:\n",
" for sentence in itertools.chain(*args):\n",
" yield from self.tokenize_sentence(sentence)\n",
"\n",
" def get_words_index(self, *args) -> dict:\n",
" unique_words = list(set(self.get_all_words(*args)))\n",
" unique_words.sort()\n",
" sorted_words_index = {w: i for i, w in enumerate(unique_words, start=1)}\n",
" return sorted_words_index\n",
"\n",
" def encode_sentence(self, sentence: str, words_index: dict) -> List[int]:\n",
" for word in self.tokenize_sentence(sentence):\n",
" yield words_index.get(word, 0)\n",
"\n",
" def encode_sentences(self, *args) -> TensorType[int]:\n",
" words_index = self.get_words_index(*args)\n",
" for sentence in itertools.chain(*args):\n",
" encoded_sentence = self.encode_sentence(sentence, words_index)\n",
" yield torch.tensor(list(encoded_sentence))\n",
"\n",
" def get_dataset(self, positive: List[str], negative: List[str]) -> TensorType[float]:\n",
" var_len_tensors = self.encode_sentences(positive, negative)\n",
" return nn.utils.rnn.pad_sequence(list(var_len_tensors), batch_first=True)"
],
"metadata": {
"id": "mUCGxSKQ-REH"
},
"execution_count": null,
"outputs": []
},
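{
"cell_type": "markdown",
"source": [
"An example run on two made-up sentence lists: words are indexed in sorted order starting from 1, and shorter sentences are right-padded with 0."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"positive = [\"great movie\", \"I loved it\"] # made-up examples\n",
"negative = [\"bad movie\"]\n",
"print(Solution().get_dataset(positive, negative))\n",
"# tensor([[3, 6, 0],\n",
"#         [1, 5, 4],\n",
"#         [2, 6, 0]])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},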
{
"cell_type": "markdown",
"source": [
"### 0.7. Sentiment analysis"
],
"metadata": {
"id": "gYwupvPyX8kp"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "oy5KNILoYGVY"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode Solution: https://neetcode.io/problems/sentiment-analysis\n",
"class Solution(nn.Module):\n",
" def __init__(self, vocabulary_size: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.embedding_layer = nn.Embedding(vocabulary_size, 16)\n",
" self.linear_layer = nn.Linear(16, 1)\n",
" self.sigmoid_layer = nn.Sigmoid()\n",
"\n",
" def forward(self, x: TensorType[int]) -> TensorType[float]:\n",
" embeddings = self.embedding_layer(x)\n",
" averaged = torch.mean(embeddings, axis = 1)\n",
" projected = self.linear_layer(averaged)\n",
" return torch.round(self.sigmoid_layer(projected), decimals=4)"
],
"metadata": {
"id": "36JSmn_qYJPS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"* use [`torch.mean(..., dim=...)`][torch.mean] instead of `torch.mean(..., axis=...)` which doesn't exist;\n",
"* shorter and less complex instructions per line of code (which is especially handy for testing).\n",
"\n",
"[torch.mean]: https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean"
],
"metadata": {
"id": "-bk4IcqRYrPv"
}
},
{
"cell_type": "code",
"source": [
"class Solution(nn.Module):\n",
"\n",
" def __init__(self, vocabulary_size: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.embedding_layer = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=EMB_LAYER_LEN)\n",
" self.linear_layer = nn.Linear(EMB_LAYER_LEN, OUTPUT_LAYER_LEN)\n",
" self.sigmoid_layer = nn.Sigmoid()\n",
"\n",
" def forward(self, x: TensorType[int]) -> TensorType[float]:\n",
" ret = self.embedding_layer(x)\n",
" ret = torch.mean(ret, dim=T_DIM_INDEX)\n",
" ret = self.linear_layer(ret)\n",
" ret = self.sigmoid_layer(ret)\n",
" ret = torch.round(ret, decimals=DECIMAL_PRECISION)\n",
" return ret"
],
"metadata": {
"id": "MVrJ3JNgY8qC"
},
"execution_count": null,
"outputs": []
},
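{
"cell_type": "markdown",
"source": [
"A shape check with made-up numbers: a batch of 2 sentences of 5 token ids each, over a hypothetical vocabulary of size 10, yields one sigmoid score per sentence."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"model = Solution(vocabulary_size=10)\n",
"tokens = torch.randint(0, 10, (2, 5)) # random token ids (made-up)\n",
"print(model(tokens).shape) # torch.Size([2, 1])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},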
{
"cell_type": "markdown",
"source": [
"### 0.8. GPT dataset"
],
"metadata": {
"id": "GmGRqozKlHL9"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode Solution"
],
"metadata": {
"id": "JoAeSYDjocBZ"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/gpt-dataset\n",
"class Solution:\n",
" def batch_loader(self, raw_dataset: str, context_length: int, batch_size: int) -> Tuple[List[List[str]]]:\n",
" torch.manual_seed(0)\n",
" tokenized = raw_dataset.split()\n",
" indices = torch.randint(low=0, high=len(tokenized) - context_length, size=(batch_size,)).tolist()\n",
" X = []\n",
" Y = []\n",
" for idx in indices:\n",
" X.append(tokenized[idx:idx+context_length])\n",
" Y.append(tokenized[idx+1:idx+1+context_length])\n",
" return X, Y"
],
"metadata": {
"id": "lx5Cl8DoohFV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* put an explicit value for token delimiter, and use a constant for that purpose instead of a hardcoded value (aka “magic string”);\n",
"* use a regex pattern for token delimiter;\n",
"* improve readability by breaking down a complex one-liner instruction (`indices = ...`) into multiple lines of code;\n",
"* wrap repeated instructions into separate functions for better readability and reusability: here, X and Y lists are built using same function, called twice, each time with a different value for the `offset=` parameter;\n",
"* use [generators][term-generator] (instead of filling up a list, intended to be iterated on later);\n",
"* use [list comprehension][term-list-comprehension] and [dictionary comprehension][term-dictionary-comprehension] expressions to fill up lists and dictionaries;\n",
"\n",
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n",
"[term-list-comprehension]: https://docs.python.org/3/glossary.html#term-list-comprehension\n",
"[term-dictionary-comprehension]: https://docs.python.org/3/glossary.html#term-dictionary-comprehension"
],
"metadata": {
"id": "EGEjE6X_ovWA"
}
},
{
"cell_type": "code",
"source": [
"class Solution:\n",
"\n",
" def tokenize(self, text: str) -> str:\n",
" yield from [word for word in re.split(TOKEN_DELIMITER, text) if word != '']\n",
"\n",
" def build_batch(self, words: List[str], context_length: int, indices: TensorType[int], offset: int=0) -> List[List[str]]:\n",
" for i in indices:\n",
" yield words[i + offset:i + offset + context_length]\n",
"\n",
" def batch_loader(self, raw_dataset: str,\n",
" context_length: int,\n",
" batch_size: int) -> Tuple[List[List[str]]]:\n",
" words = list(self.tokenize(raw_dataset))\n",
" torch.manual_seed(0)\n",
" indices = torch.randint(low=0,\n",
" high=len(words) - context_length,\n",
" size=(batch_size,))\\\n",
" .tolist()\n",
" X = list(self.build_batch(words, context_length, indices))\n",
" Y = list(self.build_batch(words, context_length, indices, offset=1))\n",
" return X, Y"
],
"metadata": {
"id": "mRg1R4ggozIF"
},
"execution_count": null,
"outputs": []
},
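{
"cell_type": "markdown",
"source": [
"An illustrative call on a made-up sentence: each row of `Y` is the matching row of `X` shifted one token to the right, which is exactly the next-token-prediction target."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"raw_dataset = \"the quick brown fox jumps over the lazy dog\" # made-up\n",
"X, Y = Solution().batch_loader(raw_dataset, context_length=3, batch_size=2)\n",
"print(X) # e.g. [['over', 'the', 'lazy'], ...]\n",
"print(Y) # same windows shifted one word right, e.g. [['the', 'lazy', 'dog'], ...]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},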
{
"cell_type": "markdown",
"source": [
"## 1. Build a homemade GPT from scratch"
],
"metadata": {
"id": "cCDTG28v6I7v"
}
},
{
"cell_type": "markdown",
"source": [
"### 1.1. Self-attention class"
],
"metadata": {
"id": "0W8A0qXT6R6i"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "pA6FtZzS7EIO"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/self-attention\n",
"class SingleHeadAttention(nn.Module):\n",
"\n",
" def __init__(self, embedding_dim: int, attention_dim: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.key_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
" self.query_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
" self.value_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" k = self.key_gen(embedded)\n",
" q = self.query_gen(embedded)\n",
" v = self.value_gen(embedded)\n",
"\n",
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n",
" context_length, attention_dim = k.shape[1], k.shape[2]\n",
" scores = scores / (attention_dim ** 0.5)\n",
"\n",
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
" mask = lower_triangular == 0\n",
" scores = scores.masked_fill(mask, float('-inf'))\n",
" scores = nn.functional.softmax(scores, dim = 2)\n",
"\n",
" return torch.round(scores @ v, decimals=4)"
],
"metadata": {
"id": "S0AIVkl47J6D"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* better flexibility, reusability, extensibility: add a Boolean parameter to the constructor function to decide whether the output should be rounded or not;\n",
"* no over-optimization: use `math.sqrt(x)` instead of `x ** 0.5`;\n",
" * Why? Computing a square root with `x ** 0.5` as an attempt at optimization one'd naturally do in C or C++ is actually an example of a technically wrong choice when it comes to Python: in fact `math.sqrt(x)`, which calls a C binary under the hood, is [significantly faster][math.sqrt] than `x ** 0.5` since the earliest release of Python 3.\n",
" * `numpy` module also has an equivalent function and much more. Therefore, use `numpy` if `numpy` has already been loaded earlier or is intended to be loaded later.\n",
"* wrap repeated instructions into separate functions for better readability and reusability;\n",
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
"\n",
"[math.sqrt]: https://stackoverflow.com/questions/327002/which-is-faster-in-python-x-5-or-math-sqrtx/327048#327048"
],
"metadata": {
"id": "Sms2SRPj7MAt"
}
},
{
"cell_type": "code",
"source": [
"class MySingleHeadAttention(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int, head_size: int, round_output: bool=False):\n",
" super().__init__()\n",
" # torch.manual_seed(0)\n",
" self.key_layer = nn.Linear(model_dim, head_size, bias=False)\n",
" self.query_layer = nn.Linear(model_dim, head_size, bias=False)\n",
" self.value_layer = nn.Linear(model_dim, head_size, bias=False)\n",
" self.round_output = round_output\n",
"\n",
" def bool_mask(self, context_length: int, filler_val: int, device) -> TensorType[bool]:\n",
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
" return (lower_triangular == filler_val).to(device)\n",
"\n",
" which_power_for_e = lambda self, x: np.log(x) if x != 0 else -np.inf\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" k = self.key_layer(embedded)\n",
" q = self.query_layer(embedded)\n",
" v = self.value_layer(embedded)\n",
" context_length, attention_dim = k.shape[1], k.shape[2]\n",
" scores = q @ torch.transpose(k, 1, 2)\n",
" scores = scores / np.sqrt(attention_dim)\n",
" mask = self.bool_mask(context_length, MASK_FILLER, DEVICE)\n",
" scores = scores.masked_fill(mask, self.which_power_for_e(MASK_FILLER))\n",
" scores = nn.functional.softmax(scores, dim=2)\n",
" scores = scores @ v\n",
" scores = torch.round(scores, decimals=DECIMAL_PRECISION) if self.round_output else scores\n",
" return scores"
],
"metadata": {
"id": "js9bPAv47PnB"
},
"execution_count": null,
"outputs": []
},
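{
"cell_type": "markdown",
"source": [
"A quick shape check on random data (all dimensions are made up): with `model_dim=8` and `head_size=4`, a batch of 2 sequences of length 5 maps to a 2×5×4 output."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"attention = MySingleHeadAttention(model_dim=8, head_size=4).to(DEVICE)\n",
"embedded = torch.rand(2, 5, 8).to(DEVICE) # batch × context × model_dim\n",
"print(attention(embedded).shape) # torch.Size([2, 5, 4])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},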
{
"cell_type": "markdown",
"source": [
"### 1.2. Multi-headed self-attention class"
],
"metadata": {
"id": "1khI9cVg7q3x"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "YSbZMCte7q32"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/multi-headed-self-attention\n",
"class MultiHeadedSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, embedding_dim: int, attention_dim: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.attention_heads = nn.ModuleList()\n",
" for i in range(num_heads):\n",
" self.attention_heads.append(self.SingleHeadAttention(embedding_dim, attention_dim // num_heads))\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" head_outputs = []\n",
" for head in self.attention_heads:\n",
" head_outputs.append(head(embedded))\n",
" concatenated = torch.cat(head_outputs, dim = 2)\n",
" return torch.round(concatenated, decimals=4)\n",
"\n",
" class SingleHeadAttention(nn.Module):\n",
" def __init__(self, embedding_dim: int, attention_dim: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.key_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
" self.query_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
" self.value_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" k = self.key_gen(embedded)\n",
" q = self.query_gen(embedded)\n",
" v = self.value_gen(embedded)\n",
"\n",
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n",
" context_length, attention_dim = k.shape[1], k.shape[2]\n",
" scores = scores / (attention_dim ** 0.5)\n",
"\n",
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
" mask = lower_triangular == 0\n",
" scores = scores.masked_fill(mask, float('-inf'))\n",
" scores = nn.functional.softmax(scores, dim = 2)\n",
"\n",
" return scores @ v"
],
"metadata": {
"id": "tp5Oxy5V7q35"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* prefer classes aggregation (aka weak coupling) to composition (aka strong coupling): instead of being `MultiHeadedSelfAttention`'s inner class, `SingleHeadAttention` should be independent, which is better for flexibility, reusability, extensibility;\n",
"* add a Boolean parameter to the constructor function to decide whether the output should be rounded or not;\n",
"* wrap repeated instructions into separate functions for better readability and reusability;\n",
"* use [generators][term-generator] (instead of filling up a list, intended to be iterated on later);\n",
"\n",
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n"
],
"metadata": {
"id": "4vKZznxR7q37"
}
},
{
"cell_type": "code",
"source": [
"class MyMultiHeadedSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int, round_output: bool=False):\n",
" super().__init__()\n",
" # torch.manual_seed(0)\n",
" head_size = model_dim // num_heads\n",
" self.attention_heads = nn.ModuleList()\n",
" for _ in range(num_heads):\n",
" self.attention_heads.append(MySingleHeadAttention(model_dim, head_size))\n",
" self.compute = nn.Linear(model_dim, model_dim)\n",
" self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n",
" self.round_output = round_output\n",
"\n",
" def get_head_outputs(self, embedded):\n",
" for att_head in self.attention_heads:\n",
" yield att_head(embedded)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" head_outputs = self.get_head_outputs(embedded)\n",
" ret = torch.cat(list(head_outputs), dim=2)\n",
" ret = self.compute(ret)\n",
" ret = self.dropout(ret)\n",
" ret = torch.round(ret, decimals=DECIMAL_PRECISION) if self.round_output else ret\n",
" return ret"
],
"metadata": {
"id": "SkEbE7T17q3-"
},
"execution_count": null,
"outputs": []
},
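{
"cell_type": "markdown",
"source": [
"A shape check with made-up dimensions: each of the 2 heads produces `model_dim // 2` features and the head outputs are concatenated, so the output keeps the input's `model_dim` of 8."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"mhsa = MyMultiHeadedSelfAttention(model_dim=8, num_heads=2).to(DEVICE)\n",
"embedded = torch.rand(2, 5, 8).to(DEVICE)\n",
"print(mhsa(embedded).shape) # torch.Size([2, 5, 8])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},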
{
"cell_type": "markdown",
"source": [
"### 1.3. Transformer block class"
],
"metadata": {
"id": "QGZjrjHW8IRg"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "MFMwmN7K8IRq"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/transformer-block\n",
"class TransformerBlock(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.mhsa = self.MultiHeadedSelfAttention(model_dim, num_heads)\n",
" self.vanilla_nn = self.VanillaNeuralNetwork(model_dim)\n",
" self.layer_norm_one = nn.LayerNorm(model_dim)\n",
" self.layer_norm_two = nn.LayerNorm(model_dim)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" # Round answer to 4 decimal places\n",
" torch.manual_seed(0)\n",
" embedded = embedded + self.mhsa(self.layer_norm_one(embedded)) # skip connection\n",
" embedded = embedded + self.vanilla_nn(self.layer_norm_two(embedded)) # another skip connection\n",
" return torch.round(embedded, decimals=4)\n",
"\n",
"\n",
" class MultiHeadedSelfAttention(nn.Module):\n",
"\n",
" class SingleHeadAttention(nn.Module):\n",
" def __init__(self, model_dim: int, head_size: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.key_gen = nn.Linear(model_dim, head_size, bias=False)\n",
" self.query_gen = nn.Linear(model_dim, head_size, bias=False)\n",
" self.value_gen = nn.Linear(model_dim, head_size, bias=False)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" k = self.key_gen(embedded)\n",
" q = self.query_gen(embedded)\n",
" v = self.value_gen(embedded)\n",
"\n",
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n",
" context_length, attention_dim = k.shape[1], k.shape[2]\n",
" scores = scores / (attention_dim ** 0.5)\n",
"\n",
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
" mask = lower_triangular == 0\n",
" scores = scores.masked_fill(mask, float('-inf'))\n",
" scores = nn.functional.softmax(scores, dim = 2)\n",
"\n",
" return scores @ v\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.attention_heads = nn.ModuleList()\n",
" for i in range(num_heads):\n",
" self.attention_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads))\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" head_outputs = []\n",
" for head in self.attention_heads:\n",
" head_outputs.append(head(embedded))\n",
" concatenated = torch.cat(head_outputs, dim = 2)\n",
" return concatenated\n",
"\n",
" class VanillaNeuralNetwork(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.first_linear_layer = nn.Linear(model_dim, model_dim * 4)\n",
" self.relu = nn.ReLU()\n",
" self.second_linear_layer = nn.Linear(model_dim * 4, model_dim)\n",
" self.dropout = nn.Dropout(0.2) # using p = 0.2\n",
"\n",
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" return self.dropout(self.second_linear_layer(self.relu(self.first_linear_layer(x))))"
],
"metadata": {
"id": "c8YatG_r8IRt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* prefer classes aggregation (aka weak coupling) to composition (aka strong coupling): each class should be independent from the others, which is better for flexibility, reusability, extensibility;\n",
"* add a Boolean parameter to the constructor function to decide whether the output should be rounded or not;\n",
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n"
],
"metadata": {
"id": "zx3pgmiF8IRv"
}
},
{
"cell_type": "code",
"source": [
"class MyTransformerBlock(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int, round_output: bool=False):\n",
" super().__init__()\n",
" # torch.manual_seed(0)\n",
" self.mhsa = MyMultiHeadedSelfAttention(model_dim, num_heads)\n",
" self.vanilla_nn = MyVanillaNeuralNetwork(model_dim)\n",
" self.layer_norm_one = nn.LayerNorm(model_dim)\n",
" self.layer_norm_two = nn.LayerNorm(model_dim)\n",
" self.round_output = round_output\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" # torch.manual_seed(0)\n",
" embedded += self.mhsa(self.layer_norm_one(embedded))\n",
" embedded += self.vanilla_nn(self.layer_norm_two(embedded))\n",
" embedded = torch.round(embedded, decimals=DECIMAL_PRECISION) if self.round_output else embedded\n",
" return embedded\n",
"\n",
"\n",
"class MyVanillaNeuralNetwork(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int):\n",
" super().__init__()\n",
" # torch.manual_seed(0)\n",
" self.first_linear_layer = nn.Linear(model_dim, model_dim * LINEAR_NN_SCALE_FACTOR)\n",
" self.relu = nn.ReLU()\n",
" self.second_linear_layer = nn.Linear(model_dim * LINEAR_NN_SCALE_FACTOR, model_dim)\n",
" self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n",
"\n",
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n",
" # torch.manual_seed(0)\n",
" ret = self.first_linear_layer(x)\n",
" ret = self.relu(ret)\n",
" ret = self.second_linear_layer(ret)\n",
" ret = self.dropout(ret)\n",
" return ret"
],
"metadata": {
"id": "KofH8zTU8IRy"
},
"execution_count": null,
"outputs": []
},
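{
"cell_type": "markdown",
"source": [
"A shape check with made-up dimensions: thanks to the skip connections, a transformer block maps its input to an output of the same shape, which is what allows the blocks to be stacked freely."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"block = MyTransformerBlock(model_dim=8, num_heads=2).to(DEVICE)\n",
"embedded = torch.rand(2, 5, 8).to(DEVICE)\n",
"print(block(embedded).shape) # torch.Size([2, 5, 8]), same as the input"
],
"metadata": {},
"execution_count": null,
"outputs": []
},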
{
"cell_type": "markdown",
"source": [
"### 1.4. GPT model class"
],
"metadata": {
"id": "gSyqjSYm8Wrq"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "gYozYnOQ8Wru"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/code-gpt\n",
"class GPT(nn.Module):\n",
"\n",
" def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.word_embeddings = nn.Embedding(vocab_size, model_dim)\n",
" self.position_embeddings = nn.Embedding(context_length, model_dim)\n",
" self.transformer_blocks = nn.Sequential()\n",
" for i in range(num_blocks):\n",
" self.transformer_blocks.append(self.TransformerBlock(model_dim, num_heads))\n",
" self.layer_norm_three = nn.LayerNorm(model_dim)\n",
" self.vocab_projection = nn.Linear(model_dim, vocab_size)\n",
"\n",
" def forward(self, context: TensorType[int]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" embedded = self.word_embeddings(context)\n",
" context_length = context.shape[1]\n",
" positions = torch.arange(context_length)\n",
" embedded = embedded + self.position_embeddings(positions)\n",
"\n",
" raw_output = self.vocab_projection(self.layer_norm_three(self.transformer_blocks(embedded)))\n",
" # raw_output is batch by context_length by vocab_size\n",
"\n",
" probabilities = nn.functional.softmax(raw_output, dim = -1)\n",
" return torch.round(probabilities, decimals=4)\n",
"\n",
" class TransformerBlock(nn.Module):\n",
"\n",
" class MultiHeadedSelfAttention(nn.Module):\n",
"\n",
" class SingleHeadAttention(nn.Module):\n",
" def __init__(self, model_dim: int, head_size: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.key_gen = nn.Linear(model_dim, head_size, bias=False)\n",
" self.query_gen = nn.Linear(model_dim, head_size, bias=False)\n",
" self.value_gen = nn.Linear(model_dim, head_size, bias=False)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" k = self.key_gen(embedded)\n",
" q = self.query_gen(embedded)\n",
" v = self.value_gen(embedded)\n",
"\n",
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n",
" context_length, attention_dim = k.shape[1], k.shape[2]\n",
" scores = scores / (attention_dim ** 0.5)\n",
"\n",
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
" mask = lower_triangular == 0\n",
" scores = scores.masked_fill(mask, float('-inf'))\n",
" scores = nn.functional.softmax(scores, dim = 2)\n",
"\n",
" return scores @ v\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.attention_heads = nn.ModuleList()\n",
" for i in range(num_heads):\n",
" self.attention_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads))\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" head_outputs = []\n",
" for head in self.attention_heads:\n",
" head_outputs.append(head(embedded))\n",
" concatenated = torch.cat(head_outputs, dim = 2)\n",
" return concatenated\n",
"\n",
" class VanillaNeuralNetwork(nn.Module):\n",
"\n",
" def __init__(self, model_dim: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.first_linear_layer = nn.Linear(model_dim, model_dim * 4)\n",
" self.relu = nn.ReLU()\n",
" self.second_linear_layer = nn.Linear(model_dim * 4, model_dim)\n",
" self.dropout = nn.Dropout(0.2) # using p = 0.2\n",
"\n",
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" return self.dropout(self.second_linear_layer(self.relu(self.first_linear_layer(x))))\n",
"\n",
" def __init__(self, model_dim: int, num_heads: int):\n",
" super().__init__()\n",
" torch.manual_seed(0)\n",
" self.mhsa = self.MultiHeadedSelfAttention(model_dim, num_heads)\n",
" self.vanilla_nn = self.VanillaNeuralNetwork(model_dim)\n",
" self.layer_norm_one = nn.LayerNorm(model_dim)\n",
" self.layer_norm_two = nn.LayerNorm(model_dim)\n",
"\n",
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
" torch.manual_seed(0)\n",
" embedded = embedded + self.mhsa(self.layer_norm_one(embedded)) # skip connection\n",
" embedded = embedded + self.vanilla_nn(self.layer_norm_two(embedded)) # another skip connection\n",
" return embedded"
],
"metadata": {
"id": "q0h-utP08Wrv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* prefer classes aggregation (aka weak coupling) to composition (aka strong coupling): each class should be independent from the others, which is better for flexibility, reusability, extensibility;\n",
"* add a Boolean parameter to the constructor function to decide whether the output should be rounded or not;\n",
"* wrap repeated instructions into separate functions for better readability and reusability;\n",
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n",
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n"
],
"metadata": {
"id": "09I7ov_N8Wrv"
}
},
{
"cell_type": "code",
"source": [
"class MyGPT(nn.Module):\n",
"\n",
" def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int, round_output: bool=False):\n",
" super().__init__()\n",
" # torch.manual_seed(0)\n",
" self.token_embedding = nn.Embedding(vocab_size, model_dim)\n",
" self.pos_embedding = nn.Embedding(context_length, model_dim)\n",
" self.transformer_blocks = nn.Sequential()\n",
" for _ in range(num_blocks):\n",
" self.transformer_blocks.append(MyTransformerBlock(model_dim, num_heads))\n",
" self.layer_norm_three = nn.LayerNorm(model_dim)\n",
" self.vocab_projection = nn.Linear(model_dim, vocab_size)\n",
" self.round_output = round_output\n",
"\n",
" def get_positions(self, context, device):\n",
" context_len = context.shape[1]\n",
" return torch.arange(context_len).to(device)\n",
"\n",
" def forward(self, context: TensorType[int]) -> TensorType[float]:\n",
" # torch.manual_seed(0)\n",
" embedded = self.token_embedding(context)\n",
" positions = self.get_positions(context, DEVICE)\n",
" embedded += self.pos_embedding(positions)\n",
" ret = self.transformer_blocks(embedded)\n",
" ret = self.layer_norm_three(ret)\n",
" ret = self.vocab_projection(ret)\n",
" ret = torch.round(ret, decimals=DECIMAL_PRECISION) if self.round_output else ret\n",
" return ret"
],
"metadata": {
"id": "XgubHNYe8Wrw"
},
"execution_count": null,
"outputs": []
},
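{
"cell_type": "markdown",
"source": [
"A smoke test with small, made-up hyperparameters: for every position in the context the model returns one logit per vocabulary entry, hence an output of shape batch × context_length × vocab_size."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"tiny_gpt = MyGPT(vocab_size=len(WORDS_INDEX), context_length=16, model_dim=8, num_blocks=2, num_heads=2).to(DEVICE)\n",
"context = torch.randint(0, len(WORDS_INDEX), (1, 16)).to(DEVICE)\n",
"print(tiny_gpt(context).shape) # torch.Size([1, 16, 104])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},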
{
"cell_type": "markdown",
"source": [
"### 1.5. Make GPT talk back"
],
"metadata": {
"id": "fRGhx_by8xiX"
}
},
{
"cell_type": "markdown",
"source": [
"#### NeetCode solution"
],
"metadata": {
"id": "382xJUI88xic"
}
},
{
"cell_type": "code",
"source": [
"# NeetCode solution: https://neetcode.io/problems/make-gpt-talk-back\n",
"class Solution:\n",
" def generate(self, model, new_chars: int, context: TensorType[int], context_length: int, int_to_char: dict) -> str:\n",
" generator = torch.manual_seed(0)\n",
" initial_state = generator.get_state()\n",
" res = []\n",
" for i in range(new_chars):\n",
" if len(context.T) > context_length:\n",
" context = context[:, -context_length:]\n",
" prediction = model(context) # B, T, Vocab_Size\n",
" last_time_step = prediction[:, -1, :] # B, Vocab_Size\n",
" probabilities = nn.functional.softmax(last_time_step, dim = -1)\n",
" next_char = torch.multinomial(probabilities, 1, generator=generator)\n",
" generator.set_state(initial_state)\n",
" context = torch.cat((context, next_char), dim = -1)\n",
" res.append(int_to_char[next_char.item()])\n",
" return ''.join(res)"
],
"metadata": {
"id": "2aYvYuLe8xie"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Refactored solution\n",
"* return a random result or a follow-up to a prompt;\n",
"* use a [generator][term-generator] to yield the result bit by bit, instead of waiting for the result to be complete to return it;\n",
"* wrap repeated instructions into separate functions for better readability and reusability;\n",
"\n",
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n"
],
"metadata": {
"id": "0-kTjPo18xig"
}
},
{
"cell_type": "code",
"source": [
"class Runner:\n",
"\n",
" def encode_prompt(self, prompt: str, words_index: tuple, device: torch.device) -> TensorType[int]:\n",
" context = torch.zeros(1, 1, dtype=torch.int64).to(device)\n",
" for char_to_encode in prompt:\n",
" char_index = words_index.index(char_to_encode)\n",
" context = torch.cat((context, torch.tensor([[next_predicted_index]])), dim=-1)\n",
" return context\n",
"\n",
" def left_trim_context(self, context: TensorType[int], context_length: int) -> TensorType[int]:\n",
" if len(context.T) > context_length:\n",
" context = context[:, -context_length:]\n",
" return context\n",
"\n",
" def generate(self, model, output_len: int, context_length: int, context: TensorType[int]=None) -> int:\n",
" if context is None:\n",
" context = torch.zeros(1, 1, dtype=torch.int64)\n",
" for i in range(output_len):\n",
" context = self.left_trim_context(context, context_length)\n",
" prediction = model(context) # B x T x Vocab_Size\n",
" last_time_step = prediction[:, -1, :]\n",
" probabilities = nn.functional.softmax(last_time_step, dim=-1)\n",
" next_predicted_index = torch.multinomial(probabilities, 1)\n",
" context = torch.cat((context, next_predicted_index), dim=-1)\n",
" yield next_predicted_index.item()"
],
"metadata": {
"id": "82wwnK_U8xii"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### 1.6. Download and plug in the pre-trained model"
],
"metadata": {
"id": "vX6JVta_tuCm"
}
},
{
"cell_type": "code",
"source": [
"%cd /content\n",
"!git clone https://github.com/gptandchill/drake-lyric-generator\n",
"%cd drake-lyric-generator\n",
"%ls"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P8wDHQmot0M2",
"outputId": "214f6211-303c-4ed1-e903-d61dec1bd142"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n",
"Cloning into 'drake-lyric-generator'...\n",
"remote: Enumerating objects: 3, done.\u001b[K\n",
"remote: Counting objects: 100% (3/3), done.\u001b[K\n",
"remote: Compressing objects: 100% (2/2), done.\u001b[K\n",
"remote: Total 3 (delta 0), reused 0 (delta 0), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (3/3), 16.53 MiB | 23.51 MiB/s, done.\n",
"/content/drake-lyric-generator\n",
"weights.pt\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Define the hyperparameters, instantiate the model, and load in the weights from training. The prior cell downloads weights.pt into this Colab runtime."
],
"metadata": {
"id": "ckqEPmJUsnfF"
}
},
{
"cell_type": "code",
"source": [
"path_to_pre_trained_model = 'weights.pt'\n",
"vocab_size = len(WORDS_INDEX)\n",
"context_length = 128\n",
"model_dim = 252\n",
"num_blocks = 6\n",
"num_heads = 6\n",
"\n",
"model = MyGPT(vocab_size, context_length, model_dim, num_blocks, num_heads).to(DEVICE)\n",
"pre_trained_model_state_dict = torch.load(path_to_pre_trained_model, map_location=DEVICE)\n",
"model.load_state_dict(pre_trained_model_state_dict)\n",
"_ = model.eval()"
],
"metadata": {
"id": "aBPCi79SskEn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### 1.7. Run model to generate lyrics"
],
"metadata": {
"id": "Sijuqs-xMi1n"
}
},
{
"cell_type": "markdown",
"source": [
"#### 1.7.1. Generate random lyrics"
],
"metadata": {
"id": "OmYfLP60tPrS"
}
},
{
"cell_type": "code",
"source": [
"output_len = 150\n",
"for next_predicted_index in Runner().generate(model, output_len, context_length):\n",
" next_char = WORDS_INDEX[next_predicted_index]\n",
" print('', end=next_char)"
],
"metadata": {
"id": "qpwXg0YitK0w",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "17142ad0-6c91-413c-878c-812ba67b55fe"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"You don't know that\n",
"Put the was fuckin' attien\n",
"And I got still like a mil with the fireworks\n",
"Tatin' wifey, they don't even regrin' it so like my kill "
]
}
]
},
{
"cell_type": "markdown",
"source": [
"#### 1.7.2. Generate lyrics as a follow-up to a prompt"
],
"metadata": {
"id": "zscxLW_70gex"
}
},
{
"cell_type": "code",
"source": [
"output_len = 150\n",
"prompt = \"I was born to \"\n",
"context = Runner().encode_prompt(prompt, WORDS_INDEX, DEVICE)\n",
"print(prompt, end='')\n",
"for next_predicted_index in Runner().generate(model, output_len, context_length, context):\n",
" next_char = WORDS_INDEX[next_predicted_index]\n",
" print('', end=next_char)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ff7c5533-e5a7-43bb-f4b6-0a2e19d4297f",
"id": "SsrBY2Wb0gez"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"I was born to anothough, got the just like I love you s\n",
"It's all your soldier flow it\n",
"Uppy, now after tryna distage is to getting if a madrista\n",
"There's so smoker fe"
]
}
]
}
]
}