@mitchellnw
Last active December 25, 2023 20:50
relu-attention-fp32.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true,
"gpuType": "T4",
"authorship_tag": "ABX9TyPdKsPUF0jkG5Jt0Chs63+6",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mitchellnw/0b48501f1c2e5043f3d6e666b5475695/triton-t4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Known issues:\n",
"- Slow speeds with large number of heads\n",
"- Slow speeds when wrapped in autograd.Function"
],
"metadata": {
"id": "PceoDskjtmLq"
}
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "z64jJbxz0HtE",
"outputId": "44650383-4cbb-41df-c597-875f274e0173"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n",
"\n"
]
}
],
"source": [
"!export LC_ALL=\"en_US.UTF-8\"\n",
"!export LD_LIBRARY_PATH=\"/usr/lib64-nvidia\"\n",
"!export LIBRARY_PATH=\"/usr/local/cuda/lib64/stubs\"\n",
"!ldconfig /usr/lib64-nvidia"
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import numpy as np\n",
"import triton\n",
"import triton.language as tl"
],
"metadata": {
"id": "6AwR00GE0jHV"
},
"execution_count": 56,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Forward pass"
],
"metadata": {
"id": "RIw5lOEf1K18"
}
},
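{
"cell_type": "markdown",
"source": [
"For reference, the kernel below computes the following (this matches the PyTorch expression used in the correctness check further down; `is_squared=True` additionally squares the masked ReLU output before the $1/L$ scaling):\n",
"\n",
"$$O = \\frac{1}{L}\\, M \\odot \\mathrm{relu}\\left(\\frac{Q K^\\top}{\\sqrt{D_h}}\\right) V$$\n",
"\n",
"where $M$ is the lower-triangular causal mask (all ones when `is_causal=False`), $L$ is the sequence length and $D_h$ the head dimension."
],
"metadata": {}
},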
{
"cell_type": "code",
"source": [
"@triton.jit\n",
"def relu_attn_(q_ptr,\n",
" k_ptr,\n",
" v_ptr,\n",
" o_ptr,\n",
" Dh: tl.constexpr, # head dim\n",
" L: tl.constexpr, # seqlen\n",
" Nh: tl.constexpr, # num heads\n",
" B: tl.constexpr, # batchsize\n",
" sm_scale: tl.constexpr, # 1/sqrt(Dh)\n",
" relu_scale: tl.constexpr, # 1/L\n",
" is_causal: tl.constexpr,\n",
" is_squared: tl.constexpr,\n",
" BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n",
" ):\n",
" # Q, K, V is of size [B, L, Nh, Dh]\n",
" pid = tl.program_id(axis=0) # current program id\n",
" currB = (pid * BLOCK_SIZE) // (Nh * L) # current batch idx\n",
" currL = (BLOCK_SIZE * pid) % L\n",
" currNh = ((BLOCK_SIZE * pid) // L) % Nh\n",
" # Common offsets\n",
" block_start = currB*Nh*L*Dh + currL*Nh*Dh + currNh*Dh\n",
" bsz_offset = tl.arange(0, BLOCK_SIZE)\n",
" common_offset = tl.arange(0, Dh)[None, :] + bsz_offset[:, None]*(Dh*Nh)\n",
" # Always keep q in mem\n",
" q = tl.load(q_ptr + block_start + common_offset)\n",
" # Accum.\n",
" acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
" # Loop over seqlen in BLOCK_SIZE chunks\n",
" upper = currL + 1 if is_causal else L\n",
" for l in range(0, upper, BLOCK_SIZE):\n",
" common_kv_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
" k = tl.load(k_ptr + common_kv_offset)\n",
" v = tl.load(v_ptr + common_kv_offset)\n",
" qk = tl.dot(q * sm_scale, tl.trans(k)) # TODO: why is bfloat cast required\n",
" # causal masking and relu\n",
" mask = (qk >= 0)\n",
" if is_causal:\n",
" mask *= ((currL + bsz_offset)[:, None] >= (l + bsz_offset)[None, :])\n",
" qk = tl.where(mask, qk, 0.)\n",
" if is_squared:\n",
" qk *= qk\n",
" acc += tl.dot(relu_scale * qk, v) # TODO: why is bfloat cast required\n",
" tl.store(o_ptr + block_start + common_offset, acc)\n",
"\n",
"\n",
"def relu_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = True, is_squared: bool = False):\n",
" output = torch.empty_like(q)\n",
" B, L, Nh, Dh = q.shape\n",
" BLOCK_SIZE = min(L, 64)\n",
" grid = lambda meta: ((B * Nh * L) // BLOCK_SIZE, )\n",
" relu_attn_[grid](q, k, v, output, Dh, L, Nh, B, 1./np.sqrt(Dh), 1./L, is_causal=is_causal, is_squared=is_squared, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, num_stages=1)\n",
" return output"
],
"metadata": {
"id": "p10gyE1q0wzt"
},
"execution_count": 57,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Correctness check forward pass"
],
"metadata": {
"id": "akMOJxaA1P4q"
}
},
{
"cell_type": "code",
"source": [
"B, L, Nh, Dh = 4, 128, 2, 16 # Not sure these shapes are realistic? Just a test.\n",
"Q = torch.randn(B, L, Nh, Dh).cuda()\n",
"K = torch.randn(B, L, Nh, Dh).cuda()\n",
"V = torch.randn(B, L, Nh, Dh).cuda()\n",
"\n",
"KK = K.swapaxes(1, 2)\n",
"QQ = Q.swapaxes(1, 2)\n",
"VV = V.swapaxes(1, 2)\n",
"\n",
"gt_O = torch.matmul((1./L) * torch.nn.functional.relu(torch.matmul(QQ / np.sqrt(Dh), KK.swapaxes(-2, -1))) * torch.ones(1, 1, L, L).cuda().tril(), VV).swapaxes(1, 2)\n",
"O = relu_attn(Q, K, V)\n",
"print((gt_O - O).abs().mean())"
],
"metadata": {
"id": "91ujF-Jd02A5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "87f55313-ec13-4f54-9bb9-38c5478be89a"
},
"execution_count": 58,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(0., device='cuda:0')\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Speed check forward pass"
],
"metadata": {
"id": "tS7kBSyb1U-c"
}
},
{
"cell_type": "code",
"source": [
"%timeit torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))"
],
"metadata": {
"id": "gMZQY_1n1CRG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "645bbf67-71bc-4359-cfed-77ac3fd2ff54"
},
"execution_count": 59,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"72.9 µs ± 4.35 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"relu_attn(Q, K, V)\n",
"%timeit relu_attn(Q, K, V)"
],
"metadata": {
"id": "OU4_xNSa1g57",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f888d5d6-2813-4cb5-e6ff-14e2c57b98ce"
},
"execution_count": 60,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"52.3 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"relu_attn(Q, K, V, is_squared=True)\n",
"%timeit relu_attn(Q, K, V, is_squared=True)"
],
"metadata": {
"id": "zORlZJdp1ji8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "211d9c95-acc5-4cd2-b4cd-5b1241ba364d"
},
"execution_count": 61,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"60.4 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Backward pass"
],
"metadata": {
"id": "F-JEYRvB5OfS"
}
},
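{
"cell_type": "markdown",
"source": [
"For reference, the gradients for the non-squared case (the configuration verified below) are just a restatement of the PyTorch reference computation in the correctness check. With $S = M \\odot \\frac{Q K^\\top}{\\sqrt{D_h}}$, $A = \\frac{1}{L}\\,\\mathrm{relu}(S)$ and $O = A V$:\n",
"\n",
"$$dV = A^\\top dO, \\qquad dA = dO\\, V^\\top, \\qquad dS = \\frac{1}{L}\\, M \\odot \\mathbf{1}[S \\ge 0] \\odot dA,$$\n",
"\n",
"$$dQ = \\frac{1}{\\sqrt{D_h}}\\, dS\\, K, \\qquad dK = \\frac{1}{\\sqrt{D_h}}\\, dS^\\top Q.$$\n",
"\n",
"The kernel accumulates these block-by-block: part 1 computes $dK$ and $dV$ for the current block, part 2 computes $dQ$."
],
"metadata": {}
},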
{
"cell_type": "code",
"source": [
"@triton.jit\n",
"def relu_attn_bwd_(q_ptr,\n",
" k_ptr,\n",
" v_ptr,\n",
" do_ptr,\n",
" dq_ptr,\n",
" dk_ptr,\n",
" dv_ptr,\n",
" Dh: tl.constexpr, # head dim\n",
" L: tl.constexpr, # seqlen\n",
" Nh: tl.constexpr, # num heads\n",
" B: tl.constexpr, # batchsize\n",
" sm_scale: tl.constexpr, # 1/sqrt(Dh)\n",
" relu_scale: tl.constexpr, # 1/L\n",
" is_causal: tl.constexpr,\n",
" is_squared: tl.constexpr,\n",
" BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n",
" ):\n",
" # Q, K, V is of size [B, L, Nh, Dh]\n",
" pid = tl.program_id(axis=0) # current program id\n",
" currB = (pid * BLOCK_SIZE) // (Nh * L) # current batch idx\n",
" currL = (BLOCK_SIZE * pid) % L\n",
" currNh = ((BLOCK_SIZE * pid) // L) % Nh\n",
" # Common offsets\n",
" block_start = currB*Nh*L*Dh + currL*Nh*Dh + currNh*Dh\n",
" bsz_offset = tl.arange(0, BLOCK_SIZE)\n",
" common_offset = tl.arange(0, Dh)[None, :] + bsz_offset[:, None]*(Dh*Nh)\n",
"\n",
" # # Accumulators\n",
" dv_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
" dk_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
" dq_acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
"\n",
" # Part 1.\n",
" k = tl.load(k_ptr + block_start + common_offset)\n",
" v = tl.load(v_ptr + block_start + common_offset)\n",
" lower = currL if is_causal else 0\n",
" for l in range(lower, L, BLOCK_SIZE):\n",
" common_sub_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
" q = tl.load(q_ptr + common_sub_offset)\n",
" do = tl.load(do_ptr + common_sub_offset)\n",
" qk = tl.dot(sm_scale * q, tl.trans(k))\n",
" mask = (qk >= 0)\n",
" if is_causal:\n",
" mask = mask * ((l + bsz_offset)[:, None] >= (currL + bsz_offset)[None, :])\n",
" qk = tl.where(mask, qk, 0.)\n",
" da = tl.dot(do, tl.trans(v) * relu_scale) * sm_scale\n",
" da = tl.where(mask, da, 0.)\n",
" if is_squared:\n",
" da *= qk\n",
" qk *= qk\n",
" dv_acc += relu_scale * tl.dot(tl.trans(qk), do)\n",
" dk = tl.dot(tl.trans(da), q)\n",
" dk_acc += dk\n",
" tl.store(dv_ptr + block_start + common_offset, dv_acc)\n",
" tl.store(dk_ptr + block_start + common_offset, dk_acc)\n",
"\n",
" # Part 2.\n",
" q = tl.load(q_ptr + block_start + common_offset)\n",
" do = tl.load(do_ptr + block_start + common_offset)\n",
" upper = currL + 1 if is_causal else L\n",
" for l in range(0, upper, BLOCK_SIZE):\n",
" common_sub_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
" k = tl.load(k_ptr + common_sub_offset)\n",
" v = tl.load(v_ptr + common_sub_offset)\n",
" qk = tl.dot(sm_scale * q, tl.trans(k))\n",
" mask = (qk >= 0)\n",
" if is_causal:\n",
" mask = mask * ((currL + bsz_offset)[:, None] >= (l + bsz_offset)[None, :])\n",
" da = tl.dot(do, tl.trans(v) * relu_scale) * sm_scale\n",
" da = tl.where(mask, da, 0.)\n",
" if is_squared:\n",
" da *= qk\n",
" dq = tl.dot(da, k)\n",
" dq_acc += dq\n",
" tl.store(dq_ptr + block_start + common_offset, dq_acc)\n",
"\n",
"\n",
"def relu_attn_bwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, do: torch.Tensor, is_causal: bool = True, is_squared: bool = False):\n",
" # We need to preallocate the output.\n",
" dq = torch.empty_like(q)\n",
" dk = torch.empty_like(k)\n",
" dv = torch.empty_like(v)\n",
"\n",
" B, L, Nh, Dh = q.shape\n",
" BLOCK_SIZE = min(L, 64)\n",
" grid = lambda meta: ((B * Nh * L) // BLOCK_SIZE,)\n",
" relu_attn_bwd_[grid](q, k, v, do, dq, dk, dv, Dh=Dh, L=L, Nh=Nh, B=B, sm_scale=1./np.sqrt(Dh), relu_scale=1./L, is_causal=is_causal, is_squared=is_squared, BLOCK_SIZE=BLOCK_SIZE, num_warps=4)\n",
" return dq, dk, dv"
],
"metadata": {
"id": "yPIKUdCg5QOm"
},
"execution_count": 62,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Correctness check backward pass"
],
"metadata": {
"id": "cDrYA2qD5WV8"
}
},
{
"cell_type": "code",
"source": [
"B, L, Nh, Dh = 4, 128, 2, 16 # Not sure these shapes are realistic? Just a test.\n",
"Q = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
"K = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
"V = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
"dO = torch.randn(B, L, Nh, Dh).cuda()#.bfloat16()\n",
"\n",
"## standard forward\n",
"QQ = Q.swapaxes(1, 2)\n",
"KK = K.swapaxes(1, 2)\n",
"VV = V.swapaxes(1, 2)\n",
"dOO = dO.swapaxes(1, 2)\n",
"\n",
"QKT = torch.matmul(QQ / np.sqrt(Dh), KK.swapaxes(-2, -1)) * torch.ones(1, 1, L, L).cuda().tril()#.bfloat16()\n",
"preA = torch.nn.functional.relu(QKT)\n",
"A = (1./L) * (preA)# * preA)\n",
"O = torch.matmul(A, VV)\n",
"\n",
"dVV = torch.matmul(A.swapaxes(-2, -1), dOO)\n",
"dA = torch.matmul(dOO, VV.swapaxes(-2, -1))\n",
"dQKT = (1./L) * (1./np.sqrt(Dh)) * dA #* preA\n",
"dQKT = torch.where(QKT >= 0, dQKT, 0.) * torch.ones(1, 1, L, L).cuda().tril()#.bfloat16()\n",
"dQQ = torch.matmul(dQKT, KK)\n",
"dKK = torch.matmul(dQKT.swapaxes(-2, -1), QQ)\n",
"\n",
"dQ = dQQ.swapaxes(1, 2)\n",
"dK = dKK.swapaxes(1, 2)\n",
"dV = dVV.swapaxes(1, 2)\n",
"\n",
"## triton backward pass\n",
"tdQ, tdK, tdV = relu_attn_bwd(Q, K, V, dO)"
],
"metadata": {
"id": "y0VXOJOu5QiR"
},
"execution_count": 63,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# test equality\n",
"print((dV - tdV).abs().mean())\n",
"print((dK - tdK).abs().mean())\n",
"print((dQ - tdQ).abs().mean())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P-ntsCFj5Yuy",
"outputId": "7ff1ca59-13ef-4a5f-b967-32fbc943202d"
},
"execution_count": 64,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(2.1916e-09, device='cuda:0')\n",
"tensor(0., device='cuda:0')\n",
"tensor(0., device='cuda:0')\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Speed check forward + backward pass\n"
],
"metadata": {
"id": "uH0W87NLmw_C"
}
},
{
"cell_type": "code",
"source": [
"# How long does it take with flash attention v2?\n",
"Q.requires_grad_(True)\n",
"K.requires_grad_(True)\n",
"V.requires_grad_(True)\n",
"def forward_and_backward():\n",
" loss = torch.mean(torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh)))\n",
" loss.backward()\n",
"%timeit forward_and_backward()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4xyHLJNJm0Im",
"outputId": "272c1523-9330-4165-81d5-e45434c2b92f"
},
"execution_count": 65,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"426 µs ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# How long does it take with our triton kernels\n",
"relu_attn(Q, K, V)\n",
"relu_attn_bwd(Q, K, V, dO)\n",
"def forward_and_backward():\n",
" relu_attn(Q, K, V)\n",
" relu_attn_bwd(Q, K, V, dO)\n",
"%timeit forward_and_backward()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZwhCNMzTnH8N",
"outputId": "037e98b6-aa90-4118-9692-d74c0fc2b308"
},
"execution_count": 66,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"143 µs ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Putting it together"
],
"metadata": {
"id": "SNMP6ZAZoZO5"
}
},
{
"cell_type": "code",
"source": [
"class ReluAttention(torch.autograd.Function):\n",
" @staticmethod\n",
" def forward(ctx, q, k, v, is_squared, is_causal):\n",
" ctx.save_for_backward = q, k, v, is_causal, is_squared\n",
" return relu_attn(q, k, v, is_causal=is_causal, is_squared=is_squared)\n",
"\n",
" @staticmethod\n",
" def backward(ctx, g):\n",
" q, k, v, is_causal, is_squared = ctx.save_for_backward\n",
" dq, dk, dv = relu_attn_bwd(q, k, v, g, is_causal=is_causal, is_squared=is_squared)\n",
" return dq, dk, dv, None, None\n",
"\n",
"relu_attention = ReluAttention.apply"
],
"metadata": {
"id": "xoM3IDXdobPt"
},
"execution_count": 67,
"outputs": []
},
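{
"cell_type": "markdown",
"source": [
"As a sanity check for the wrapper (a minimal sketch, not executed here; it reuses the `Q`, `K`, `V`, `dO` tensors and the `relu_attn_bwd` kernel from above), the gradients produced by autograd through `relu_attention` should match the hand-written backward pass:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: compare autograd-through-the-wrapper gradients with relu_attn_bwd.\n",
"Qg = Q.detach().clone().requires_grad_(True)\n",
"Kg = K.detach().clone().requires_grad_(True)\n",
"Vg = V.detach().clone().requires_grad_(True)\n",
"out = relu_attention(Qg, Kg, Vg, False, True)  # is_squared=False, is_causal=True\n",
"out.backward(dO)\n",
"dq_ref, dk_ref, dv_ref = relu_attn_bwd(Q, K, V, dO, is_causal=True, is_squared=False)\n",
"print((Qg.grad - dq_ref).abs().max())\n",
"print((Kg.grad - dk_ref).abs().max())\n",
"print((Vg.grad - dv_ref).abs().max())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},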
{
"cell_type": "code",
"source": [
"# How long does it take with flash attention v2?\n",
"Q.requires_grad_(True)\n",
"K.requires_grad_(True)\n",
"V.requires_grad_(True)\n",
"def forward_and_backward():\n",
" loss = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))\n",
" loss[0, 0, 0, 0].backward()\n",
"%timeit forward_and_backward()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "W6zzTJzgprUo",
"outputId": "86590d45-33a4-4ca8-ff0b-4dfa662e631a"
},
"execution_count": 68,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"503 µs ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Why would wrapping this take so much more time?\n",
"Q.requires_grad_(True)\n",
"K.requires_grad_(True)\n",
"V.requires_grad_(True)\n",
"def forward_and_backward():\n",
" loss = relu_attention(Q, K, V, False, True)\n",
" loss[0, 0, 0, 0].backward()\n",
"%timeit forward_and_backward()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IuytqTedpzPr",
"outputId": "9f1ac1b2-c45f-405e-c7f8-635b0db385a8"
},
"execution_count": 69,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"651 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
]
}
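,
{
"cell_type": "markdown",
"source": [
"One way to narrow down where the extra time goes (a sketch, not executed here): time the wrapped forward on its own against the raw kernel launch, so the `autograd.Function` bookkeeping can be separated from the `.backward()` call."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: isolate autograd overhead from kernel launch time.\n",
"%timeit relu_attn(Q, K, V)\n",
"%timeit relu_attention(Q, K, V, False, True)"
],
"metadata": {},
"execution_count": null,
"outputs": []
}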
]
}