{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zaccharieramzi/10e64574e68510b901bd08ba1894e13c/attention-scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s_jAcr48kgIj"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"# if torch.backends.mps.is_available(): device = torch.device(\"mps\") # special for Mac"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jjO62A2ikgIh"
},
"source": [
"### Multi-head Attention from Scratch\n",
"\n",
"This notebook implements from scratch, in a step-by-step fashion, a multi-head self-attention layer, which gives the same output as the Pytorch implementation.\n",
"\n",
"References:\n",
"- Original transformer paper: [Attention is all you need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)\n",
"- PyTorch implementation is in the function [`multi_head_attention_forward`](\n",
"https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py)\n",
"\n",
"by [Zaccharie Ramzi](https://zaccharieramzi.fr/) and [Gabriel Peyré](http://www.gpeyre.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fhVrcbzwkgIk"
},
"source": [
"The data is composed of batches of $p$ points $(x_s^b)_{s=0}^{p-1}$ in $\\mathbb{R}^d$ stored in a matrix `X` of size $(n_b,p,d)$. Here $n_b$ is the number of batches, so that $b$ runs in $0 \\ldots n_b-1$. These $n_b$ batches are processed in parallel. Note that we use here the \"batch first\" format."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "54FOnOPIkgIk"
},
"outputs": [],
"source": [
"n_b = 8 # size of batch, processed in parallel\n",
"p = 80 # number of points in each points cloud\n",
"d = 12 # dimension of the points\n",
"X = torch.randn(n_b, p, d, device=device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kPNMgSPzkgIl"
},
"source": [
"Generate the parameter of the attention layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hxP1eBy1kgIl"
},
"outputs": [],
"source": [
"K = torch.randn(d, d, device=device)\n",
"Q = torch.randn(d, d, device=device)\n",
"V = torch.randn(d, d, device=device)\n",
"L = torch.randn(d, d, device=device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tkp5_aamkgIl"
},
"source": [
"$\\newcommand{\\coloneqq}{:=}$\n",
"First the points are transformed in Keys, Queries, Values using matrices $K \\in \\mathbb{R}^{d \\times d}$, $Q \\in \\mathbb{R}^{d \\times d}$, $V \\in \\mathbb{R}^{d \\times d}$\n",
"$$\n",
" \\forall s = 0,\\ldots,p-1, \\quad\n",
" k_s^b \\coloneqq K x_i^b, \\quad\n",
" q_s^b \\coloneqq Q x_i^b, \\quad\n",
" v_s^b \\coloneqq V x_i^b.\n",
"$$\n",
"These points are stored in the arrays `KX,QX,VX` of size $(n_b,p,d)$.\n",
"\n",
"We use Einstein summation notations to compute the transform, this is very useful and should be prefered over direct array manipulation (transposition, etc)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "zlY6K5kPkgIm"
},
"outputs": [],
"source": [
"KX = torch.einsum(\"ij,bsj->bsi\", [K, X])\n",
"QX = torch.einsum(\"ij,bsj->bsi\", [Q, X])\n",
"VX = torch.einsum(\"ij,bsj->bsi\", [V, X])"
]
},
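{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, the einsum above should coincide (up to floating-point error) with a plain batched matrix product: applying $K$ to every point is the same as right-multiplying `X` by `K.T`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: the einsum should match a batched matmul with the transposed matrix.\n",
"print(torch.allclose(KX, X @ K.T, atol=1e-5))  # expected: True"
]
},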
{
"cell_type": "markdown",
"metadata": {
"id": "4POiS_7mkgIm"
},
"source": [
"Then each of these points such as $k_s^b \\in \\mathbb{R}^{d}$ are split into $n_h$ (\"number of heads\") points $k_{s}^{b,h} \\in \\mathbb{R}^{d_h}$ where $d_h \\coloneqq d/n_h$, i.e.\n",
"$$\n",
" k_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}), \\quad\n",
" q_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}), \\quad\n",
" v_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}).\n",
"$$\n",
"These new points are still stored in the same `KX,QX,VX`, but they have size $(n_b,p,n_h,d_h)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uWEbCNzZkgIn"
},
"outputs": [],
"source": [
"n_h = 2 # number of heads\n",
"d_h = d // n_h # dimension of each head\n",
"KX = KX.reshape(n_b, p, n_h, d_h )\n",
"QX = QX.reshape(n_b, p, n_h, d_h )\n",
"VX = VX.reshape(n_b, p, n_h, d_h )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ew5nPSM3kgIn"
},
"source": [
"We then compute, for each head $h=0,\\ldots,n_h-1$, the inner products between the keys and queries\n",
"$$\n",
" \\forall (s,t) \\in \\{0,\\ldots,p-1\\}^2, \\quad\n",
" D_{s,t}^{b,h} \\coloneqq \\langle k_{s}^{b,h}, q_{t}^{b,h} \\rangle_{\\mathbb{R}^{d_h}}\n",
"$$\n",
"and they are stored in the matrix `D` of size $(n_b,n_h,p,p)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HCSTIpqUkgIn"
},
"outputs": [],
"source": [
"D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])"
]
},
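{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same inner products can be written with explicit permutations and a batched matrix product, which may be easier to read than the einsum; the two should agree up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check: unroll the einsum into permutations and a batched matmul.\n",
"D_matmul = QX.permute(0, 2, 1, 3) @ KX.permute(0, 2, 3, 1)  # shape (n_b, n_h, p, p)\n",
"print(torch.allclose(D, D_matmul, atol=1e-4))  # expected: True"
]
},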
{
"cell_type": "markdown",
"metadata": {
"id": "ukk_wbOMkgIn"
},
"source": [
"From these, one compute the attention kernel $U$ and row-normalize it to obtain $\\tilde U$ stored in `Ut` of size $(n_b,n_h,p,p)$\n",
"$$\n",
" \\tilde U_{s,t}^{b,h} \\coloneqq \\frac{U_{s,t}^{b,h}}{\\sum_{t'} U_{s,t'}^{b,h}}\n",
" \\quad\\text{where}\\quad\n",
" U_{s,t}^{b,h} \\coloneqq e^{\\frac{D_{s,t}^{b,h}}{\\sqrt{d_h}}}.\n",
"$$\n",
"The $1/\\sqrt{d_h}$ scaling is such that, at initialization, if $(K,Q)$ are Gaussian white noise with unit variance, then the entries of $\\tilde U_{s,t}^h$ have roughly the same amplitude, which is important to ease training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "GmmQmGTIkgIo"
},
"outputs": [],
"source": [
"r = torch.sqrt(torch.tensor(d_h).double()) # note that this is the per-head dimension and not the full attention dimension\n",
"U = torch.exp(D / r)\n",
"Ut = U / torch.sum(U, axis=3, keepdim=True)"
]
},
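{
"cell_type": "markdown",
"metadata": {},
"source": [
"This exponentiate-then-normalize step is exactly a softmax along the last axis. As an optional check, it should match `torch.softmax` (which is also the numerically safer way to compute it), and each row of `Ut` should sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check: the row normalization is a softmax along the last dimension.\n",
"print(torch.allclose(Ut, torch.softmax(D / r, dim=-1)))  # expected: True\n",
"print(torch.allclose(Ut.sum(dim=-1), torch.ones_like(Ut[..., 0])))  # rows sum to 1"
]
},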
{
"cell_type": "markdown",
"metadata": {
"id": "J_iJB21YkgIo"
},
"source": [
"This kernel is then used to barycenter the values points to obtains new points\n",
"$$\n",
" \\forall s = 0,\\ldots,p-1, \\quad\n",
" z_{s}^{b,h} \\coloneqq \\sum_{t=0}^{p-1} \\tilde U_{s,t}^{b,h} v_t^b.\n",
"$$\n",
"These new points are stored in the array `Z` of size $(n_b,p,n_h,d_h)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q5tjoIWCkgIo"
},
"outputs": [],
"source": [
"Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX])"
]
},
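{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, the einsum can be unrolled into permutations and a batched matrix product if that is easier to follow; the cast to `Ut`'s dtype below is only there because `Ut` was computed in double precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check: apply the attention weights with an explicit batched matmul.\n",
"Z_matmul = (Ut @ VX.permute(0, 2, 1, 3).to(Ut.dtype)).permute(0, 2, 1, 3)  # back to (n_b, p, n_h, d_h)\n",
"print(torch.allclose(Z, Z_matmul, atol=1e-5))  # expected: True"
]
},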
{
"cell_type": "markdown",
"metadata": {
"id": "fUpWQMKRkgIo"
},
"source": [
"The output of all the heads are then grouped in new points\n",
"$$\n",
" \\forall s = 0,\\ldots,p-1, \\quad\n",
" z_{s}^{b} \\coloneqq (z_{s}^{b,0},\\ldots,z_{s}^{b,n_h-1}) \\in \\mathbb{R}^d.\n",
"$$\n",
"They are still stored in the same matrix `Z` of size $(n_b,p,d)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HFBH7YHakgIo"
},
"outputs": [],
"source": [
"Z = Z.reshape(n_b, p, n_h*d_h)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NaCu3ZUCkgIo"
},
"source": [
"Then a final linear matrix $L \\in \\mathbb{R}^{d \\times d}$ is applied independantly to each point to obtain the output\n",
"$$\n",
" \\forall s = 0,\\ldots,p-1, \\quad\n",
" y_{s}^{b} \\coloneqq L z_{s}^{b}.\n",
"$$\n",
"These points are output by the function in an array `Y` of the same size as `X`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2Ps99ZdakgIp"
},
"outputs": [],
"source": [
"Y = torch.einsum(\"ij,bsj->bsi\", [L, Z])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zm1NG5HkgIp"
},
"source": [
"Put all this in a function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YJ3p0Ba6kgIp"
},
"outputs": [],
"source": [
"def multi_head_attention(X, K, Q, V, L, n_h):\n",
" n_b, p, d = X.shape\n",
" d_h = d // n_h # dimension of the features of each head\n",
" assert( d_h * n_h == d ), \"Embedding size needs to be divisible by heads\"\n",
" # apply the matrices K,Q,V to X, and then spread them in the different heads\n",
" KX = torch.einsum(\"ij,bsj->bsi\", [K, X]).reshape( n_b, p, n_h, d_h )\n",
" QX = torch.einsum(\"ij,bsj->bsi\", [Q, X]).reshape( n_b, p, n_h, d_h )\n",
" VX = torch.einsum(\"ij,bsj->bsi\", [V, X]).reshape( n_b, p, n_h, d_h )\n",
" # compute <KX_k,QX_l>\n",
" D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])\n",
" # scaled kernel\n",
" r = torch.sqrt(torch.tensor(d_h).double())\n",
" U = torch.exp(D / r)\n",
" # row normalize (softmax)\n",
" Ut = U / torch.sum(U, axis=3)[:,:,:,None]\n",
" # apply kernel\n",
" Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX]).reshape(n_b, p, n_h*d_h)\n",
" # apply final linear layer\n",
" return torch.einsum(\"ij,bsj->bsi\", [L, Z])"
]
},
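{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick consistency check, calling the function with the same random matrices as above should reproduce the step-by-step result `Y` up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Consistency check: the function should reproduce the step-by-step computation above.\n",
"Y_fn = multi_head_attention(X, K, Q, V, L, n_h)\n",
"print(torch.allclose(Y, Y_fn, atol=1e-5))  # expected: True"
]
},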
{
"cell_type": "markdown",
"metadata": {
"id": "funvuL0XkgIp"
},
"source": [
"Compare the Pytorch implementation with out own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wIVhYALbkgIp",
"outputId": "1c43dc70-d814-4179-d3de-61f52e873251"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.9810291e-07\n"
]
}
],
"source": [
"# using pytorch code\n",
"M = torch.nn.MultiheadAttention(d, n_h, batch_first=True, dropout=0.0, bias=False, device=device) # make sure to use the batch_first arg. according to your data layout\n",
"Y_torch,_ = M(X, X, X) # self attention\n",
"\n",
"# Retrieve the Q, K, V matrices\n",
"Q = M.in_proj_weight[:d, :]\n",
"K = M.in_proj_weight[d:2*d, :]\n",
"V = M.in_proj_weight[2*d:, :]\n",
"# final projection matrix\n",
"L = M.out_proj.weight\n",
"\n",
"# using our own code\n",
"Y = multi_head_attention(X, K, Q, V, L, n_h)\n",
"\n",
"# should be 0 ...\n",
"print((torch.norm(Y_torch - Y) /torch.norm(Y)).detach().cpu().numpy() )"
]
},
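{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slicing above relies on `in_proj_weight` stacking the query, key and value projections along the first dimension, in that order (which the small relative error above confirms); an equivalent, arguably clearer way to split it is `chunk`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent way to split the stacked input projection (query, key, value order).\n",
"Q_w, K_w, V_w = M.in_proj_weight.chunk(3, dim=0)\n",
"print(torch.equal(Q_w, Q), torch.equal(K_w, K), torch.equal(V_w, V))  # expected: True True True"
]
},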
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "afbNwNKUkgIq"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "1bd775ba363a980e8663c4b5e6fe16a4d2483fcdf14e1d1ee576e4bf99bce45c"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}