Created June 3, 2024 08:13
attention-scratch.ipynb
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/zaccharieramzi/10e64574e68510b901bd08ba1894e13c/attention-scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "s_jAcr48kgIj" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"# if torch.backends.mps.is_available(): device = torch.device(\"mps\") # special for Mac" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jjO62A2ikgIh" | |
}, | |
"source": [ | |
"### Multi-head Attention from Scratch\n", | |
"\n", | |
"This notebook implements, step by step and from scratch, a multi-head self-attention layer, and checks that it gives the same output as the PyTorch implementation.\n",
"\n", | |
"References:\n", | |
"- Original transformer paper: [Attention is all you need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)\n", | |
"- The PyTorch implementation is in the function [`multi_head_attention_forward`](https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py)\n",
"\n", | |
"by [Zaccharie Ramzi](https://zaccharieramzi.fr/) and [Gabriel Peyré](http://www.gpeyre.com/)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fhVrcbzwkgIk" | |
}, | |
"source": [ | |
"The data consists of $n_b$ point clouds, each made of $p$ points $(x_s^b)_{s=0}^{p-1}$ in $\\mathbb{R}^d$, stored in a tensor `X` of size $(n_b,p,d)$. Here $n_b$ is the batch size, so that $b$ runs over $0,\\ldots,n_b-1$, and the $n_b$ point clouds are processed in parallel. Note that we use the \"batch first\" layout, i.e. the batch index $b$ comes first."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "54FOnOPIkgIk" | |
}, | |
"outputs": [], | |
"source": [ | |
"n_b = 8 # size of batch, processed in parallel\n", | |
"p = 80 # number of points in each points cloud\n", | |
"d = 12 # dimension of the points\n", | |
"X = torch.randn(n_b, p, d, device=device)" | |
] | |
}, | |
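{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `torch.nn.MultiheadAttention` expects the sequence-first layout $(p,n_b,d)$ by default, and the batch-first layout $(n_b,p,d)$ only when it is built with `batch_first=True`. The sketch below (not part of the original notebook) uses a hypothetical helper `demo_batch_first_layout`, with small illustrative sizes and local variables only, to show that switching between the two layouts is a single transpose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_batch_first_layout(n_b=4, p=10, d=6):\n",
"    # hypothetical sizes, used only for this illustration\n",
"    X_bf = torch.randn(n_b, p, d)  # batch-first layout (n_b, p, d)\n",
"    X_sf = X_bf.transpose(0, 1)    # sequence-first layout (p, n_b, d)\n",
"    # point s of point cloud b is the same vector in both layouts\n",
"    assert torch.equal(X_bf[1, 3], X_sf[3, 1])\n",
"    return tuple(X_bf.shape), tuple(X_sf.shape)\n",
"\n",
"print(demo_batch_first_layout())"
]
},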
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kPNMgSPzkgIl" | |
}, | |
"source": [ | |
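"Generate random parameters for the attention layer: the key, query and value matrices $K, Q, V$ and the final linear matrix $L$, all of size $d \\times d$."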
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "hxP1eBy1kgIl" | |
}, | |
"outputs": [], | |
"source": [ | |
"K = torch.randn(d, d, device=device)\n", | |
"Q = torch.randn(d, d, device=device)\n", | |
"V = torch.randn(d, d, device=device)\n", | |
"L = torch.randn(d, d, device=device)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Tkp5_aamkgIl" | |
}, | |
"source": [ | |
"$\\newcommand{\\coloneqq}{:=}$\n", | |
"First the points are transformed into keys, queries and values using the matrices $K \\in \\mathbb{R}^{d \\times d}$, $Q \\in \\mathbb{R}^{d \\times d}$ and $V \\in \\mathbb{R}^{d \\times d}$:\n",
"$$\n", | |
" \\forall s = 0,\\ldots,p-1, \\quad\n", | |
"  k_s^b \\coloneqq K x_s^b, \\quad\n",
"  q_s^b \\coloneqq Q x_s^b, \\quad\n",
"  v_s^b \\coloneqq V x_s^b.\n",
"$$\n", | |
"These points are stored in the arrays `KX, QX, VX` of size $(n_b,p,d)$.\n",
"\n",
"We compute these transforms with Einstein summation (`torch.einsum`). This notation is very convenient and should be preferred over direct array manipulation (transpositions, etc.)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "zlY6K5kPkgIm" | |
}, | |
"outputs": [], | |
"source": [ | |
"KX = torch.einsum(\"ij,bsj->bsi\", [K, X])\n", | |
"QX = torch.einsum(\"ij,bsj->bsi\", [Q, X])\n", | |
"VX = torch.einsum(\"ij,bsj->bsi\", [V, X])" | |
] | |
}, | |
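{
"cell_type": "markdown",
"metadata": {},
"source": [
"The einsum string `ij,bsj->bsi` computes $K x_s^b$ for every $b$ and $s$; since each point is stored as a row of `X`, this is the same as the batched matrix product `X @ K.T`. A minimal sanity check of this equivalence, sketched with a hypothetical helper `demo_einsum_projection` on small random tensors (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_einsum_projection(n_b=3, p=5, d=4):\n",
"    K_demo = torch.randn(d, d)\n",
"    X_demo = torch.randn(n_b, p, d)\n",
"    # k_s^b = K x_s^b for every b and s\n",
"    KX_einsum = torch.einsum('ij,bsj->bsi', K_demo, X_demo)\n",
"    # same computation as a batched matrix product (points are stored as rows)\n",
"    KX_matmul = X_demo @ K_demo.T\n",
"    assert torch.allclose(KX_einsum, KX_matmul, atol=1e-6)\n",
"    print('einsum and matmul agree')\n",
"\n",
"demo_einsum_projection()"
]
},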
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "4POiS_7mkgIm" | |
}, | |
"source": [ | |
"Then each of these points, such as $k_s^b \\in \\mathbb{R}^{d}$, is split into $n_h$ (\"number of heads\") sub-vectors $k_{s}^{b,h} \\in \\mathbb{R}^{d_h}$, where $d_h \\coloneqq d/n_h$ (so $d$ must be divisible by $n_h$), i.e.\n",
"$$\n", | |
" k_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}), \\quad\n", | |
"  q_s^b = (q_{s}^{b,0},\\ldots,q_{s}^{b,n_h-1}), \\quad\n",
"  v_s^b = (v_{s}^{b,0},\\ldots,v_{s}^{b,n_h-1}).\n",
"$$\n",
"These sub-vectors are still stored in the same arrays `KX, QX, VX`, which now have size $(n_b,p,n_h,d_h)$."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "uWEbCNzZkgIn" | |
}, | |
"outputs": [], | |
"source": [ | |
"n_h = 2 # number of heads\n", | |
"d_h = d // n_h # dimension of each head\n", | |
"KX = KX.reshape(n_b, p, n_h, d_h )\n", | |
"QX = QX.reshape(n_b, p, n_h, d_h )\n", | |
"VX = VX.reshape(n_b, p, n_h, d_h )" | |
] | |
}, | |
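{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `reshape` follows the row-major order of the coordinates, head $h$ receives the contiguous block of coordinates $h d_h, \\ldots, (h+1) d_h - 1$ of each vector. A small check of this fact, sketched with a hypothetical helper `demo_head_split` on random data (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_head_split(n_b=2, p=3, d=6, n_h=2):\n",
"    d_h = d // n_h\n",
"    A = torch.randn(n_b, p, d)\n",
"    A_heads = A.reshape(n_b, p, n_h, d_h)\n",
"    for h in range(n_h):\n",
"        # head h holds coordinates h*d_h, ..., (h+1)*d_h - 1 of every point\n",
"        assert torch.equal(A_heads[:, :, h, :], A[:, :, h * d_h:(h + 1) * d_h])\n",
"    print('each head is a contiguous block of coordinates')\n",
"\n",
"demo_head_split()"
]
},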
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Ew5nPSM3kgIn" | |
}, | |
"source": [ | |
"We then compute, for each head $h=0,\\ldots,n_h-1$, the inner products between the queries and the keys\n",
"$$\n",
"  \\forall (s,t) \\in \\{0,\\ldots,p-1\\}^2, \\quad\n",
"  D_{s,t}^{b,h} \\coloneqq \\langle q_{s}^{b,h}, k_{t}^{b,h} \\rangle_{\\mathbb{R}^{d_h}}\n",
"$$\n", | |
"and they are stored in the matrix `D` of size $(n_b,n_h,p,p)$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "HCSTIpqUkgIn" | |
}, | |
"outputs": [], | |
"source": [ | |
"D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])" | |
] | |
}, | |
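{
"cell_type": "markdown",
"metadata": {},
"source": [
"Equivalently, once the head axis is moved before the point axis, $D^{b,h}$ is the $p \\times p$ product of the matrix whose rows are the queries $q_s^{b,h}$ with the transpose of the matrix whose rows are the keys $k_t^{b,h}$. A sketch of this equivalence, with a hypothetical helper `demo_scores_matmul` on random data (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_scores_matmul(n_b=2, p=5, n_h=3, d_h=4):\n",
"    QX_d = torch.randn(n_b, p, n_h, d_h)\n",
"    KX_d = torch.randn(n_b, p, n_h, d_h)\n",
"    D_einsum = torch.einsum('bshi,bthi->bhst', QX_d, KX_d)\n",
"    # move the head axis before the point axis: (n_b, n_h, p, d_h)\n",
"    Qh = QX_d.permute(0, 2, 1, 3)\n",
"    Kh = KX_d.permute(0, 2, 1, 3)\n",
"    # one p x p matrix of inner products per batch element and head\n",
"    D_matmul = Qh @ Kh.transpose(-2, -1)\n",
"    assert D_einsum.shape == (n_b, n_h, p, p)\n",
"    assert torch.allclose(D_einsum, D_matmul, atol=1e-5)\n",
"    print('einsum scores match the batched matrix product')\n",
"\n",
"demo_scores_matmul()"
]
},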
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ukk_wbOMkgIn" | |
}, | |
"source": [ | |
"From these, we compute the attention kernel $U$ and row-normalize it (a softmax over $t$) to obtain $\\tilde U$, stored in `Ut` of size $(n_b,n_h,p,p)$:\n",
"$$\n", | |
" \\tilde U_{s,t}^{b,h} \\coloneqq \\frac{U_{s,t}^{b,h}}{\\sum_{t'} U_{s,t'}^{b,h}}\n", | |
" \\quad\\text{where}\\quad\n", | |
" U_{s,t}^{b,h} \\coloneqq e^{\\frac{D_{s,t}^{b,h}}{\\sqrt{d_h}}}.\n", | |
"$$\n", | |
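"The $1/\\sqrt{d_h}$ scaling is such that, at initialization, the entries of $\\tilde U_{s,t}^{b,h}$ have roughly the same amplitude, which eases training. Indeed, if the coordinates of $q_s^{b,h}$ and $k_t^{b,h}$ are independent with zero mean and unit variance, then $D_{s,t}^{b,h}$ has variance $d_h$, and dividing by $\\sqrt{d_h}$ brings it back to unit variance, so the softmax does not concentrate almost all of its mass on a few entries as $d_h$ grows."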
] | |
}, | |
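{
"cell_type": "markdown",
"metadata": {},
"source": [
"The variance argument can be checked numerically. This sketch assumes, as above, that the coordinates of the queries and keys are independent standard Gaussians; `demo_scaling_variance` is a hypothetical helper, and the printed values are Monte Carlo estimates, so they are only approximately $d_h$ and $1$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_scaling_variance(d_h=64, n_samples=20000):\n",
"    # assumption: i.i.d. standard Gaussian coordinates for queries and keys\n",
"    q = torch.randn(n_samples, d_h)\n",
"    k = torch.randn(n_samples, d_h)\n",
"    dots = (q * k).sum(dim=1)  # one inner product per sample\n",
"    print('variance of <q, k>:', dots.var().item())  # approximately d_h\n",
"    print('variance of <q, k> / sqrt(d_h):', (dots / d_h ** 0.5).var().item())  # approximately 1\n",
"\n",
"demo_scaling_variance()"
]
},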
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "GmmQmGTIkgIo" | |
}, | |
"outputs": [], | |
"source": [ | |
"r = d_h ** 0.5  # note that this is the per-head dimension d_h and not the full embedding dimension d\n",
"U = torch.exp(D / r)\n",
"Ut = U / torch.sum(U, dim=3, keepdim=True)  # normalize each row over t (softmax)"
] | |
}, | |
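{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since $\\tilde U$ is a softmax along $t$, it can also be computed with `torch.softmax(D / r, dim=3)`. Subtracting the row maximum of $D/r$ before exponentiating leaves $\\tilde U$ unchanged (numerator and denominator are multiplied by the same factor) but keeps `torch.exp` from overflowing when the scores are large. The sketch below, with a hypothetical helper `demo_stable_softmax` on random scores (local variables only), checks that the three variants agree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_stable_softmax(n_b=2, n_h=2, p=5, d_h=4):\n",
"    D_s = 10 * torch.randn(n_b, n_h, p, p)  # random scores, for illustration only\n",
"    S = D_s / d_h ** 0.5\n",
"    # direct formula, as in the cell above\n",
"    U = torch.exp(S)\n",
"    Ut = U / U.sum(dim=3, keepdim=True)\n",
"    # subtracting the row maximum multiplies numerator and denominator by the same factor\n",
"    U_stable = torch.exp(S - S.amax(dim=3, keepdim=True))\n",
"    Ut_stable = U_stable / U_stable.sum(dim=3, keepdim=True)\n",
"    assert torch.allclose(Ut, Ut_stable, atol=1e-6)\n",
"    assert torch.allclose(Ut, torch.softmax(S, dim=3), atol=1e-6)\n",
"    print('direct, stabilized and torch.softmax versions agree')\n",
"\n",
"demo_stable_softmax()"
]
},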
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "J_iJB21YkgIo" | |
}, | |
"source": [ | |
"This kernel is then used to compute, for each head, weighted averages (barycenters) of the value points, which gives new points\n",
"$$\n",
"  \\forall s = 0,\\ldots,p-1, \\quad\n",
"  z_{s}^{b,h} \\coloneqq \\sum_{t=0}^{p-1} \\tilde U_{s,t}^{b,h} v_t^{b,h}.\n",
"$$\n", | |
"These new points are stored in the array `Z` of size $(n_b,p,n_h,d_h)$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "q5tjoIWCkgIo" | |
}, | |
"outputs": [], | |
"source": [ | |
"Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX])" | |
] | |
}, | |
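{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each $b$ and $h$, this einsum is the product of the $p \\times p$ matrix $\\tilde U^{b,h}$ with the $p \\times d_h$ matrix whose rows are the values $v_t^{b,h}$. Since each row of $\\tilde U^{b,h}$ sums to one, every $z_s^{b,h}$ is a convex combination of the values. A sketch checking both facts, with a hypothetical helper `demo_barycenter` on random data (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_barycenter(n_b=2, p=5, n_h=3, d_h=4):\n",
"    Ut_d = torch.softmax(torch.randn(n_b, n_h, p, p), dim=3)  # random row-stochastic kernel\n",
"    VX_d = torch.randn(n_b, p, n_h, d_h)\n",
"    Z_einsum = torch.einsum('bhst,bthi->bshi', Ut_d, VX_d)\n",
"    # per (b, h): kernel times values, then move the head axis back after the point axis\n",
"    Z_matmul = (Ut_d @ VX_d.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)\n",
"    assert torch.allclose(Z_einsum, Z_matmul, atol=1e-6)\n",
"    # rows of the kernel sum to 1, so each output is a convex combination of the values\n",
"    assert torch.allclose(Ut_d.sum(dim=3), torch.ones(n_b, n_h, p))\n",
"    print('einsum barycenters match the batched matrix product')\n",
"\n",
"demo_barycenter()"
]
},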
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fUpWQMKRkgIo" | |
}, | |
"source": [ | |
"The outputs of all the heads are then concatenated into new points\n",
"$$\n", | |
" \\forall s = 0,\\ldots,p-1, \\quad\n", | |
" z_{s}^{b} \\coloneqq (z_{s}^{b,0},\\ldots,z_{s}^{b,n_h-1}) \\in \\mathbb{R}^d.\n", | |
"$$\n", | |
"They are still stored in the same array `Z`, which now has size $(n_b,p,d)$."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "HFBH7YHakgIo" | |
}, | |
"outputs": [], | |
"source": [ | |
"Z = Z.reshape(n_b, p, n_h*d_h)" | |
] | |
}, | |
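{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reshape concatenates the heads in order along the feature axis; it follows the logical (row-major) order of the entries, so it also works when the einsum output is not contiguous in memory. A sketch comparing it with an explicit `torch.cat`, using a hypothetical helper `demo_concat_heads` on random data (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def demo_concat_heads(n_b=2, p=5, n_h=3, d_h=4):\n",
"    Z4 = torch.randn(n_b, p, n_h, d_h)\n",
"    Z_reshaped = Z4.reshape(n_b, p, n_h * d_h)\n",
"    # explicit concatenation of heads 0, ..., n_h - 1 along the feature axis\n",
"    Z_cat = torch.cat([Z4[:, :, h, :] for h in range(n_h)], dim=-1)\n",
"    assert torch.equal(Z_reshaped, Z_cat)\n",
"    print('reshape concatenates the heads in order')\n",
"\n",
"demo_concat_heads()"
]
},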
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "NaCu3ZUCkgIo" | |
}, | |
"source": [ | |
"Then a final linear map $L \\in \\mathbb{R}^{d \\times d}$ is applied independently to each point to obtain the output\n",
"$$\n", | |
" \\forall s = 0,\\ldots,p-1, \\quad\n", | |
" y_{s}^{b} \\coloneqq L z_{s}^{b}.\n", | |
"$$\n", | |
"These points are output by the function in an array `Y` of the same size as `X`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "2Ps99ZdakgIp" | |
}, | |
"outputs": [], | |
"source": [ | |
"Y = torch.einsum(\"ij,bsj->bsi\", [L, Z])" | |
] | |
}, | |
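{
"cell_type": "markdown",
"metadata": {},
"source": [
"This last step is a bias-free linear layer: `torch.nn.functional.linear(Z, L)` computes `Z @ L.T`, which applies $L$ to every point. A sketch of this equivalence, with a hypothetical helper `demo_output_projection` on random data (local variables only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def demo_output_projection(n_b=2, p=5, d=6):\n",
"    L_d = torch.randn(d, d)\n",
"    Z_d = torch.randn(n_b, p, d)\n",
"    Y_einsum = torch.einsum('ij,bsj->bsi', L_d, Z_d)\n",
"    # F.linear(input, weight) computes input @ weight.T, i.e. y_s = L z_s for every point\n",
"    Y_linear = F.linear(Z_d, L_d)\n",
"    assert torch.allclose(Y_einsum, Y_linear, atol=1e-5)\n",
"    print('einsum matches a bias-free linear layer')\n",
"\n",
"demo_output_projection()"
]
},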
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-zm1NG5HkgIp" | |
}, | |
"source": [ | |
"We now put all these steps together in a single function."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "YJ3p0Ba6kgIp" | |
}, | |
"outputs": [], | |
"source": [ | |
"def multi_head_attention(X, K, Q, V, L, n_h):\n",
"    n_b, p, d = X.shape\n",
"    d_h = d // n_h  # dimension of the features of each head\n",
"    assert d_h * n_h == d, \"Embedding dimension d must be divisible by the number of heads n_h\"\n",
"    # apply the matrices K, Q, V to X, then split the result across the heads\n",
"    KX = torch.einsum(\"ij,bsj->bsi\", [K, X]).reshape(n_b, p, n_h, d_h)\n",
"    QX = torch.einsum(\"ij,bsj->bsi\", [Q, X]).reshape(n_b, p, n_h, d_h)\n",
"    VX = torch.einsum(\"ij,bsj->bsi\", [V, X]).reshape(n_b, p, n_h, d_h)\n",
"    # inner products <q_s, k_t> for every head\n",
"    D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])\n",
"    # scaled kernel (uses the per-head dimension d_h)\n",
"    r = d_h ** 0.5\n",
"    U = torch.exp(D / r)\n",
"    # row normalization over t (softmax)\n",
"    Ut = U / torch.sum(U, dim=3, keepdim=True)\n",
"    # apply the kernel to the values, then concatenate the heads\n",
"    Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX]).reshape(n_b, p, n_h * d_h)\n",
"    # apply the final linear layer\n",
"    return torch.einsum(\"ij,bsj->bsi\", [L, Z])"
] | |
}, | |
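{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since PyTorch 2.0, the core of this computation (scaled scores, softmax over $t$, weighted sum of the values) is also available as `torch.nn.functional.scaled_dot_product_attention`, whose default scale is $1/\\sqrt{E}$ with $E$ the last dimension of the queries, here $d_h$. The sketch below assumes PyTorch 2.0 or later and uses the head-first layout $(n_b,n_h,p,d_h)$, i.e. the heads act as an extra batch dimension; `demo_sdpa` is a hypothetical helper and the tensors are random (local variables only)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def demo_sdpa(n_b=2, n_h=3, p=5, d_h=4):\n",
"    # assumption: PyTorch >= 2.0, which provides F.scaled_dot_product_attention\n",
"    q = torch.randn(n_b, n_h, p, d_h)  # head-first layout (n_b, n_h, p, d_h)\n",
"    k = torch.randn(n_b, n_h, p, d_h)\n",
"    v = torch.randn(n_b, n_h, p, d_h)\n",
"    # explicit formula: softmax over t of the scaled scores, then weighted sum of the values\n",
"    Ut = torch.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1)\n",
"    Z_manual = Ut @ v\n",
"    # PyTorch routine; its default scale is 1 / sqrt(q.size(-1)) = 1 / sqrt(d_h)\n",
"    Z_fused = F.scaled_dot_product_attention(q, k, v)\n",
"    assert torch.allclose(Z_manual, Z_fused, atol=1e-5)\n",
"    print('manual attention matches scaled_dot_product_attention')\n",
"\n",
"demo_sdpa()"
]
},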
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "funvuL0XkgIp" | |
}, | |
"source": [ | |
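"Compare the PyTorch implementation with our own. The weights of `torch.nn.MultiheadAttention` are copied into our function: its `in_proj_weight` stacks the query, key and value matrices, in this order."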
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wIVhYALbkgIp", | |
"outputId": "1c43dc70-d814-4179-d3de-61f52e873251" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.9810291e-07\n" | |
] | |
} | |
], | |
"source": [ | |
"# using pytorch code\n", | |
"M = torch.nn.MultiheadAttention(d, n_h, batch_first=True, dropout=0.0, bias=False, device=device) # make sure to use the batch_first arg. according to your data layout\n", | |
"Y_torch,_ = M(X, X, X) # self attention\n", | |
"\n", | |
"# Retrieve the Q, K, V matrices (stacked in this order in in_proj_weight)\n",
"Q = M.in_proj_weight[:d, :]\n", | |
"K = M.in_proj_weight[d:2*d, :]\n", | |
"V = M.in_proj_weight[2*d:, :]\n", | |
"# final projection matrix\n", | |
"L = M.out_proj.weight\n", | |
"\n", | |
"# using our own code\n", | |
"Y = multi_head_attention(X, K, Q, V, L, n_h)\n", | |
"\n", | |
"# relative error: should be close to 0, up to floating-point round-off\n",
"print((torch.norm(Y_torch - Y) / torch.norm(Y)).detach().cpu().numpy())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "afbNwNKUkgIq" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3.9.12 ('base')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.12" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "1bd775ba363a980e8663c4b5e6fe16a4d2483fcdf14e1d1ee576e4bf99bce45c" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |