relu-attention-fp32.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true,
      "gpuType": "T4",
      "authorship_tag": "ABX9TyPdKsPUF0jkG5Jt0Chs63+6",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/mitchellnw/0b48501f1c2e5043f3d6e666b5475695/triton-t4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
"Known issues:\n", | |
"- Slow speeds with large number of heads\n", | |
"- Slow speeds when wrapped in autograd.Function" | |
      ],
      "metadata": {
        "id": "PceoDskjtmLq"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "z64jJbxz0HtE",
        "outputId": "44650383-4cbb-41df-c597-875f274e0173"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n",
            "\n",
            "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n",
            "\n",
            "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n",
            "\n",
            "/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n",
            "\n",
            "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n",
            "\n",
            "/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n",
            "\n"
          ]
        }
      ],
      "source": [
        "!export LC_ALL=\"en_US.UTF-8\"\n",
        "!export LD_LIBRARY_PATH=\"/usr/lib64-nvidia\"\n",
        "!export LIBRARY_PATH=\"/usr/local/cuda/lib64/stubs\"\n",
        "!ldconfig /usr/lib64-nvidia"
      ]
    },
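    {
      "cell_type": "markdown",
      "source": [
        "Note: each `!` command runs in its own subshell, so the `!export` lines above do not persist into later cells; only the `ldconfig` call has a lasting effect here. If those variables are actually needed by the notebook process, `%env LD_LIBRARY_PATH=/usr/lib64-nvidia` (or `os.environ`) would set them persistently."
      ],
      "metadata": {}
    },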
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import triton\n",
        "import triton.language as tl"
      ],
      "metadata": {
        "id": "6AwR00GE0jHV"
      },
      "execution_count": 56,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Forward pass"
      ],
      "metadata": {
        "id": "RIw5lOEf1K18"
      }
    },
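    {
      "cell_type": "markdown",
      "source": [
        "The kernel below computes, per head and with causal mask $M$ when `is_causal`:\n",
        "\n",
        "$$O = \\frac{1}{L}\\, f\\left(\\mathrm{relu}\\left(M \\odot \\frac{QK^\\top}{\\sqrt{D_h}}\\right)\\right) V,$$\n",
        "\n",
        "where $f(x) = x^2$ when `is_squared` and $f(x) = x$ otherwise. Unlike softmax attention there is no row-wise normalizer to compute, so each query block can accumulate its output in a single pass over the keys."
      ],
      "metadata": {}
    },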
    {
      "cell_type": "code",
      "source": [
        "@triton.jit\n",
        "def relu_attn_(q_ptr,\n",
        "               k_ptr,\n",
        "               v_ptr,\n",
        "               o_ptr,\n",
        "               Dh: tl.constexpr,  # head dim\n",
        "               L: tl.constexpr,  # seqlen\n",
        "               Nh: tl.constexpr,  # num heads\n",
        "               B: tl.constexpr,  # batchsize\n",
        "               sm_scale: tl.constexpr,  # 1/sqrt(Dh)\n",
        "               relu_scale: tl.constexpr,  # 1/L\n",
        "               is_causal: tl.constexpr,\n",
        "               is_squared: tl.constexpr,\n",
        "               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.\n",
        "               ):\n",
        "    # Q, K, V are of size [B, L, Nh, Dh]\n",
        "    pid = tl.program_id(axis=0)  # current program id\n",
        "    currB = (pid * BLOCK_SIZE) // (Nh * L)  # current batch idx\n",
        "    currL = (BLOCK_SIZE * pid) % L\n",
        "    currNh = ((BLOCK_SIZE * pid) // L) % Nh\n",
        "    # Common offsets\n",
        "    block_start = currB*Nh*L*Dh + currL*Nh*Dh + currNh*Dh\n",
        "    bsz_offset = tl.arange(0, BLOCK_SIZE)\n",
        "    common_offset = tl.arange(0, Dh)[None, :] + bsz_offset[:, None]*(Dh*Nh)\n",
        "    # Always keep q in mem\n",
        "    q = tl.load(q_ptr + block_start + common_offset)\n",
        "    # Accum.\n",
        "    acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
        "    # Loop over seqlen in BLOCK_SIZE chunks\n",
        "    upper = currL + 1 if is_causal else L\n",
        "    for l in range(0, upper, BLOCK_SIZE):\n",
        "        common_kv_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
        "        k = tl.load(k_ptr + common_kv_offset)\n",
        "        v = tl.load(v_ptr + common_kv_offset)\n",
        "        qk = tl.dot(q * sm_scale, tl.trans(k))  # TODO: why is bfloat cast required\n",
        "        # causal masking and relu\n",
        "        mask = (qk >= 0)\n",
        "        if is_causal:\n",
        "            mask *= ((currL + bsz_offset)[:, None] >= (l + bsz_offset)[None, :])\n",
        "        qk = tl.where(mask, qk, 0.)\n",
        "        if is_squared:\n",
        "            qk *= qk\n",
        "        acc += tl.dot(relu_scale * qk, v)  # TODO: why is bfloat cast required\n",
        "    tl.store(o_ptr + block_start + common_offset, acc)\n",
        "\n",
        "\n",
        "def relu_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = True, is_squared: bool = False):\n",
        "    output = torch.empty_like(q)\n",
        "    B, L, Nh, Dh = q.shape\n",
        "    BLOCK_SIZE = min(L, 64)\n",
        "    grid = lambda meta: ((B * Nh * L) // BLOCK_SIZE, )\n",
        "    relu_attn_[grid](q, k, v, output, Dh, L, Nh, B, 1./np.sqrt(Dh), 1./L, is_causal=is_causal, is_squared=is_squared, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, num_stages=1)\n",
        "    return output"
      ],
      "metadata": {
        "id": "p10gyE1q0wzt"
      },
      "execution_count": 57,
      "outputs": []
    },
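    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check of the program-id arithmetic: each program handles one `BLOCK_SIZE` slice of the sequence for one (batch, head) pair. The sketch below (plain Python, small illustrative sizes rather than the benchmark shapes) replays the `pid` decomposition and checks that it tiles every (batch, head, sequence-block) triple exactly once."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Replay the pid -> (currB, currNh, currL) mapping from relu_attn_ in plain\n",
        "# Python. Illustrative sizes only; this relies on L being divisible by\n",
        "# BLOCK_SIZE, as in the kernel launch above.\n",
        "B_, L_, Nh_, BLOCK_ = 2, 8, 2, 4\n",
        "seen = set()\n",
        "for pid in range((B_ * Nh_ * L_) // BLOCK_):\n",
        "    currB = (pid * BLOCK_) // (Nh_ * L_)\n",
        "    currL = (BLOCK_ * pid) % L_\n",
        "    currNh = ((BLOCK_ * pid) // L_) % Nh_\n",
        "    seen.add((currB, currNh, currL))\n",
        "# Every (batch, head, seq-block) triple should appear exactly once.\n",
        "assert len(seen) == (B_ * Nh_ * L_) // BLOCK_\n",
        "print(sorted(seen))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },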
    {
      "cell_type": "markdown",
      "source": [
        "# Correctness check forward pass"
      ],
      "metadata": {
        "id": "akMOJxaA1P4q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "B, L, Nh, Dh = 4, 128, 2, 16  # Not sure these shapes are realistic; just a test.\n",
        "Q = torch.randn(B, L, Nh, Dh).cuda()\n",
        "K = torch.randn(B, L, Nh, Dh).cuda()\n",
        "V = torch.randn(B, L, Nh, Dh).cuda()\n",
        "\n",
        "KK = K.swapaxes(1, 2)\n",
        "QQ = Q.swapaxes(1, 2)\n",
        "VV = V.swapaxes(1, 2)\n",
        "\n",
        "gt_O = torch.matmul((1./L) * torch.nn.functional.relu(torch.matmul(QQ / np.sqrt(Dh), KK.swapaxes(-2, -1))) * torch.ones(1, 1, L, L).cuda().tril(), VV).swapaxes(1, 2)\n",
        "O = relu_attn(Q, K, V)\n",
        "print((gt_O - O).abs().mean())"
      ],
      "metadata": {
        "id": "91ujF-Jd02A5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "87f55313-ec13-4f54-9bb9-38c5478be89a"
      },
      "execution_count": 58,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor(0., device='cuda:0')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Speed check forward pass"
      ],
      "metadata": {
        "id": "tS7kBSyb1U-c"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%timeit torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))"
      ],
      "metadata": {
        "id": "gMZQY_1n1CRG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "645bbf67-71bc-4359-cfed-77ac3fd2ff54"
      },
      "execution_count": 59,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "72.9 µs ± 4.35 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "relu_attn(Q, K, V)\n",
        "%timeit relu_attn(Q, K, V)"
      ],
      "metadata": {
        "id": "OU4_xNSa1g57",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f888d5d6-2813-4cb5-e6ff-14e2c57b98ce"
      },
      "execution_count": 60,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "52.3 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "relu_attn(Q, K, V, is_squared=True)\n",
        "%timeit relu_attn(Q, K, V, is_squared=True)"
      ],
      "metadata": {
        "id": "zORlZJdp1ji8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "211d9c95-acc5-4cd2-b4cd-5b1241ba364d"
      },
      "execution_count": 61,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "60.4 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Backward pass"
      ],
      "metadata": {
        "id": "F-JEYRvB5OfS"
      }
    },
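    {
      "cell_type": "markdown",
      "source": [
        "For reference, the gradients the kernel below implements. With $S = M \\odot \\frac{QK^\\top}{\\sqrt{D_h}}$, $A = \\frac{1}{L}\\,\\mathrm{relu}(S)$ and $O = AV$, the chain rule gives\n",
        "\n",
        "$$dV = A^\\top dO, \\qquad dS = \\frac{1}{L}\\,(dO\\, V^\\top) \\odot \\mathbf{1}[S \\ge 0] \\odot M, \\qquad dQ = \\frac{1}{\\sqrt{D_h}}\\, dS\\, K, \\qquad dK = \\frac{1}{\\sqrt{D_h}}\\, dS^\\top Q.$$\n",
        "\n",
        "(When `is_squared`, $A = \\frac{1}{L}\\,\\mathrm{relu}(S)^2$, so $dS$ picks up an extra factor $2\\,\\mathrm{relu}(S)$.) The kernel splits this into two passes: Part 1 fixes a key/value block and scans over query blocks to accumulate $dK$ and $dV$; Part 2 fixes a query block and scans over key/value blocks to accumulate $dQ$."
      ],
      "metadata": {}
    },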
    {
      "cell_type": "code",
      "source": [
        "@triton.jit\n",
        "def relu_attn_bwd_(q_ptr,\n",
        "                   k_ptr,\n",
        "                   v_ptr,\n",
        "                   do_ptr,\n",
        "                   dq_ptr,\n",
        "                   dk_ptr,\n",
        "                   dv_ptr,\n",
        "                   Dh: tl.constexpr,  # head dim\n",
        "                   L: tl.constexpr,  # seqlen\n",
        "                   Nh: tl.constexpr,  # num heads\n",
        "                   B: tl.constexpr,  # batchsize\n",
        "                   sm_scale: tl.constexpr,  # 1/sqrt(Dh)\n",
        "                   relu_scale: tl.constexpr,  # 1/L\n",
        "                   is_causal: tl.constexpr,\n",
        "                   is_squared: tl.constexpr,\n",
        "                   BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.\n",
        "                   ):\n",
        "    # Q, K, V are of size [B, L, Nh, Dh]\n",
        "    pid = tl.program_id(axis=0)  # current program id\n",
        "    currB = (pid * BLOCK_SIZE) // (Nh * L)  # current batch idx\n",
        "    currL = (BLOCK_SIZE * pid) % L\n",
        "    currNh = ((BLOCK_SIZE * pid) // L) % Nh\n",
        "    # Common offsets\n",
        "    block_start = currB*Nh*L*Dh + currL*Nh*Dh + currNh*Dh\n",
        "    bsz_offset = tl.arange(0, BLOCK_SIZE)\n",
        "    common_offset = tl.arange(0, Dh)[None, :] + bsz_offset[:, None]*(Dh*Nh)\n",
        "\n",
        "    # Accumulators\n",
        "    dv_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
        "    dk_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
        "    dq_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
        "\n",
        "    # Part 1: fix this k/v block, scan over query blocks to accumulate dk, dv.\n",
        "    k = tl.load(k_ptr + block_start + common_offset)\n",
        "    v = tl.load(v_ptr + block_start + common_offset)\n",
        "    lower = currL if is_causal else 0\n",
        "    for l in range(lower, L, BLOCK_SIZE):\n",
        "        common_sub_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
        "        q = tl.load(q_ptr + common_sub_offset)\n",
        "        do = tl.load(do_ptr + common_sub_offset)\n",
        "        qk = tl.dot(sm_scale * q, tl.trans(k))\n",
        "        mask = (qk >= 0)\n",
        "        if is_causal:\n",
        "            mask = mask * ((l + bsz_offset)[:, None] >= (currL + bsz_offset)[None, :])\n",
        "        qk = tl.where(mask, qk, 0.)\n",
        "        da = tl.dot(do, tl.trans(v) * relu_scale) * sm_scale\n",
        "        da = tl.where(mask, da, 0.)\n",
        "        if is_squared:\n",
        "            da *= 2 * qk  # d(x^2)/dx = 2x; qk holds relu(qk) here\n",
        "            qk *= qk\n",
        "        dv_acc += relu_scale * tl.dot(tl.trans(qk), do)\n",
        "        dk = tl.dot(tl.trans(da), q)\n",
        "        dk_acc += dk\n",
        "    tl.store(dv_ptr + block_start + common_offset, dv_acc)\n",
        "    tl.store(dk_ptr + block_start + common_offset, dk_acc)\n",
        "\n",
        "    # Part 2: fix this q block, scan over k/v blocks to accumulate dq.\n",
        "    q = tl.load(q_ptr + block_start + common_offset)\n",
        "    do = tl.load(do_ptr + block_start + common_offset)\n",
        "    upper = currL + 1 if is_causal else L\n",
        "    for l in range(0, upper, BLOCK_SIZE):\n",
        "        common_sub_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
        "        k = tl.load(k_ptr + common_sub_offset)\n",
        "        v = tl.load(v_ptr + common_sub_offset)\n",
        "        qk = tl.dot(sm_scale * q, tl.trans(k))\n",
        "        mask = (qk >= 0)\n",
        "        if is_causal:\n",
        "            mask = mask * ((currL + bsz_offset)[:, None] >= (l + bsz_offset)[None, :])\n",
        "        da = tl.dot(do, tl.trans(v) * relu_scale) * sm_scale\n",
        "        da = tl.where(mask, da, 0.)\n",
        "        if is_squared:\n",
        "            da *= 2 * qk  # d(x^2)/dx = 2x; da is already zero where the mask is off\n",
        "        dq = tl.dot(da, k)\n",
        "        dq_acc += dq\n",
        "    tl.store(dq_ptr + block_start + common_offset, dq_acc)\n",
        "\n",
        "\n",
        "def relu_attn_bwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, do: torch.Tensor, is_causal: bool = True, is_squared: bool = False):\n",
        "    # We need to preallocate the outputs.\n",
        "    dq = torch.empty_like(q)\n",
        "    dk = torch.empty_like(k)\n",
        "    dv = torch.empty_like(v)\n",
        "\n",
        "    B, L, Nh, Dh = q.shape\n",
        "    BLOCK_SIZE = min(L, 64)\n",
        "    grid = lambda meta: ((B * Nh * L) // BLOCK_SIZE,)\n",
        "    relu_attn_bwd_[grid](q, k, v, do, dq, dk, dv, Dh=Dh, L=L, Nh=Nh, B=B, sm_scale=1./np.sqrt(Dh), relu_scale=1./L, is_causal=is_causal, is_squared=is_squared, BLOCK_SIZE=BLOCK_SIZE, num_warps=4)\n",
        "    return dq, dk, dv"
      ],
      "metadata": {
        "id": "yPIKUdCg5QOm"
      },
      "execution_count": 62,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Correctness check backward pass"
      ],
      "metadata": {
        "id": "cDrYA2qD5WV8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "B, L, Nh, Dh = 4, 128, 2, 16  # Not sure these shapes are realistic; just a test.\n",
        "Q = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
        "K = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
        "V = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
        "dO = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
        "\n",
        "## standard forward\n",
        "QQ = Q.swapaxes(1, 2)\n",
        "KK = K.swapaxes(1, 2)\n",
        "VV = V.swapaxes(1, 2)\n",
        "dOO = dO.swapaxes(1, 2)\n",
        "\n",
        "QKT = torch.matmul(QQ / np.sqrt(Dh), KK.swapaxes(-2, -1)) * torch.ones(1, 1, L, L).cuda().tril()#.bfloat16()\n",
        "preA = torch.nn.functional.relu(QKT)\n",
        "A = (1./L) * (preA)# * preA)\n",
        "O = torch.matmul(A, VV)\n",
        "\n",
        "dVV = torch.matmul(A.swapaxes(-2, -1), dOO)\n",
        "dA = torch.matmul(dOO, VV.swapaxes(-2, -1))\n",
        "dQKT = (1./L) * (1./np.sqrt(Dh)) * dA #* preA\n",
        "dQKT = torch.where(QKT >= 0, dQKT, 0.) * torch.ones(1, 1, L, L).cuda().tril()#.bfloat16()\n",
        "dQQ = torch.matmul(dQKT, KK)\n",
        "dKK = torch.matmul(dQKT.swapaxes(-2, -1), QQ)\n",
        "\n",
        "dQ = dQQ.swapaxes(1, 2)\n",
        "dK = dKK.swapaxes(1, 2)\n",
        "dV = dVV.swapaxes(1, 2)\n",
        "\n",
        "## triton backward pass\n",
        "tdQ, tdK, tdV = relu_attn_bwd(Q, K, V, dO)"
      ],
      "metadata": {
        "id": "y0VXOJOu5QiR"
      },
      "execution_count": 63,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# test equality\n",
        "print((dV - tdV).abs().mean())\n",
        "print((dK - tdK).abs().mean())\n",
        "print((dQ - tdQ).abs().mean())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P-ntsCFj5Yuy",
        "outputId": "7ff1ca59-13ef-4a5f-b967-32fbc943202d"
      },
      "execution_count": 64,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor(2.1916e-09, device='cuda:0')\n",
            "tensor(0., device='cuda:0')\n",
            "tensor(0., device='cuda:0')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Speed check forward + backward pass\n"
      ],
      "metadata": {
        "id": "uH0W87NLmw_C"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# How long does it take with torch SDPA (flash attention v2 where available)?\n",
        "Q.requires_grad_(True)\n",
        "K.requires_grad_(True)\n",
        "V.requires_grad_(True)\n",
        "def forward_and_backward():\n",
        "    loss = torch.mean(torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh)))\n",
        "    loss.backward()\n",
        "%timeit forward_and_backward()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4xyHLJNJm0Im",
        "outputId": "272c1523-9330-4165-81d5-e45434c2b92f"
      },
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "426 µs ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# How long does it take with our Triton kernels? (Called directly, so no\n",
        "# autograd bookkeeping is included in this timing.)\n",
        "relu_attn(Q, K, V)\n",
        "relu_attn_bwd(Q, K, V, dO)\n",
        "def forward_and_backward():\n",
        "    relu_attn(Q, K, V)\n",
        "    relu_attn_bwd(Q, K, V, dO)\n",
        "%timeit forward_and_backward()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZwhCNMzTnH8N",
        "outputId": "037e98b6-aa90-4118-9692-d74c0fc2b308"
      },
      "execution_count": 66,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "143 µs ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Putting it together"
      ],
      "metadata": {
        "id": "SNMP6ZAZoZO5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class ReluAttention(torch.autograd.Function):\n",
        "    @staticmethod\n",
        "    def forward(ctx, q, k, v, is_squared, is_causal):\n",
        "        # save_for_backward is a method; assigning to it would shadow it and\n",
        "        # bypass autograd's tensor bookkeeping. Non-tensor flags go on ctx.\n",
        "        ctx.save_for_backward(q, k, v)\n",
        "        ctx.is_causal = is_causal\n",
        "        ctx.is_squared = is_squared\n",
        "        return relu_attn(q, k, v, is_causal=is_causal, is_squared=is_squared)\n",
        "\n",
        "    @staticmethod\n",
        "    def backward(ctx, g):\n",
        "        q, k, v = ctx.saved_tensors\n",
        "        dq, dk, dv = relu_attn_bwd(q, k, v, g, is_causal=ctx.is_causal, is_squared=ctx.is_squared)\n",
        "        return dq, dk, dv, None, None\n",
        "\n",
        "relu_attention = ReluAttention.apply"
      ],
      "metadata": {
        "id": "xoM3IDXdobPt"
      },
      "execution_count": 67,
      "outputs": []
    },
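    {
      "cell_type": "markdown",
      "source": [
        "A minimal gradient check of the wrapper (a sketch that reuses `Q`, `K`, `V`, `dO` and the reference grads `dQ`, `dK`, `dV` from the backward-pass check above):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Gradients through the autograd wrapper should match the reference dQ/dK/dV.\n",
        "Qc = Q.detach().clone().requires_grad_(True)\n",
        "Kc = K.detach().clone().requires_grad_(True)\n",
        "Vc = V.detach().clone().requires_grad_(True)\n",
        "out = relu_attention(Qc, Kc, Vc, False, True)  # is_squared=False, is_causal=True\n",
        "out.backward(dO)\n",
        "print((Qc.grad - dQ).abs().mean())\n",
        "print((Kc.grad - dK).abs().mean())\n",
        "print((Vc.grad - dV).abs().mean())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },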
    {
      "cell_type": "code",
      "source": [
        "# How long does it take with torch SDPA (flash attention v2 where available)?\n",
        "Q.requires_grad_(True)\n",
        "K.requires_grad_(True)\n",
        "V.requires_grad_(True)\n",
        "def forward_and_backward():\n",
        "    loss = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))\n",
        "    loss[0, 0, 0, 0].backward()\n",
        "%timeit forward_and_backward()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W6zzTJzgprUo",
        "outputId": "86590d45-33a4-4ca8-ff0b-4dfa662e631a"
      },
      "execution_count": 68,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "503 µs ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Why would wrapping this take so much more time?\n",
        "Q.requires_grad_(True)\n",
        "K.requires_grad_(True)\n",
        "V.requires_grad_(True)\n",
        "def forward_and_backward():\n",
        "    loss = relu_attention(Q, K, V, False, True)\n",
        "    loss[0, 0, 0, 0].backward()\n",
        "%timeit forward_and_backward()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IuytqTedpzPr",
        "outputId": "9f1ac1b2-c45f-405e-c7f8-635b0db385a8"
      },
      "execution_count": 69,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "651 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
          ]
        }
      ]
    }
  ]
} |