Skip to content

Instantly share code, notes, and snippets.

@mitchellnw
Last active December 26, 2023 04:09
Show Gist options
  • Save mitchellnw/17d529b1a5eabd38ca345e41f5002074 to your computer and use it in GitHub Desktop.
Save mitchellnw/17d529b1a5eabd38ca345e41f5002074 to your computer and use it in GitHub Desktop.
triton-a100.ipynb
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100",
"authorship_tag": "ABX9TyPwJzd6a2YBBAVCNYjs4W64",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mitchellnw/17d529b1a5eabd38ca345e41f5002074/triton-a100.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Known issues:\n",
"- Slow speeds with a large number of heads\n",
"- Slow speeds when wrapped in autograd.Function\n",
"- This notebook contains only the forward pass. See https://gist.github.com/mitchellnw/0b48501f1c2e5043f3d6e666b5475695, which has the backward pass as well."
],
"metadata": {
"id": "PceoDskjtmLq"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "z64jJbxz0HtE",
"outputId": "fd027d65-13a0-421b-d9b9-b16222082341"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n",
"\n"
]
}
],
"source": [
"!export LC_ALL=\"en_US.UTF-8\"\n",
"!export LD_LIBRARY_PATH=\"/usr/lib64-nvidia\"\n",
"!export LIBRARY_PATH=\"/usr/local/cuda/lib64/stubs\"\n",
"!ldconfig /usr/lib64-nvidia"
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import numpy as np\n",
"import triton\n",
"import triton.language as tl"
],
"metadata": {
"id": "6AwR00GE0jHV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Forward pass"
],
"metadata": {
"id": "RIw5lOEf1K18"
}
},
{
"cell_type": "code",
"source": [
"@triton.jit\n",
"def relu_attn_(q_ptr,\n",
" k_ptr,\n",
" v_ptr,\n",
" o_ptr,\n",
" Dh: tl.constexpr, # head dim\n",
" L: tl.constexpr, # seqlen\n",
" Nh: tl.constexpr, # num heads\n",
" B: tl.constexpr, # batchsize\n",
" sm_scale: tl.constexpr, # 1/sqrt(Dh)\n",
" relu_scale: tl.constexpr, # 1/L\n",
" is_causal: tl.constexpr,\n",
" is_squared: tl.constexpr,\n",
" BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n",
" ):\n",
" # Q, K, V is of size [B, L, Nh, Dh]\n",
" pid = tl.program_id(axis=0) # current program id\n",
" currB = (pid * BLOCK_SIZE) // (Nh * L) # current batch idx\n",
" currL = (BLOCK_SIZE * pid) % L\n",
" currNh = ((BLOCK_SIZE * pid) // L) % Nh\n",
" # Common offsets\n",
" block_start = currB*Nh*L*Dh + currL*Nh*Dh + currNh*Dh\n",
" bsz_offset = tl.arange(0, BLOCK_SIZE)\n",
" common_offset = tl.arange(0, Dh)[None, :] + bsz_offset[:, None]*(Dh*Nh)\n",
" # Always keep q in mem\n",
" q = tl.load(q_ptr + block_start + common_offset)\n",
" # Accum.\n",
" acc = tl.zeros((BLOCK_SIZE, Dh), dtype=tl.float32)\n",
" # Loop over seqlen in BLOCK_SIZE chunks\n",
" upper = currL + 1 if is_causal else L\n",
" for l in range(0, upper, BLOCK_SIZE):\n",
" common_kv_offset = currB*Nh*L*Dh + l*Nh*Dh + currNh*Dh + common_offset\n",
" k = tl.load(k_ptr + common_kv_offset)\n",
" v = tl.load(v_ptr + common_kv_offset)\n",
" qk = tl.dot((q * sm_scale).to(tl.bfloat16), tl.trans(k)) # TODO: why is bfloat cast required\n",
" # causal masking and relu\n",
" mask = (qk >= 0)\n",
" if is_causal:\n",
" mask *= ((currL + bsz_offset)[:, None] >= (l + bsz_offset)[None, :])\n",
" qk = tl.where(mask, qk, 0.)\n",
" if is_squared:\n",
" qk *= qk\n",
" acc += tl.dot((relu_scale * qk).to(tl.bfloat16), v) # TODO: why is bfloat cast required\n",
" tl.store(o_ptr + block_start + common_offset, acc)\n",
"\n",
"\n",
"def relu_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = True, is_squared: bool = False):\n",
" output = torch.empty_like(q)\n",
" B, L, Nh, Dh = q.shape\n",
" BLOCK_SIZE = min(L, 64)\n",
" grid = lambda meta: ((B * Nh * L) // BLOCK_SIZE, )\n",
" relu_attn_[grid](q, k, v, output, Dh, L, Nh, B, 1./np.sqrt(Dh), 1./L, is_causal=is_causal, is_squared=is_squared, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, num_stages=1)\n",
" return output"
],
"metadata": {
"id": "p10gyE1q0wzt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Correctness check forward pass"
],
"metadata": {
"id": "akMOJxaA1P4q"
}
},
{
"cell_type": "code",
"source": [
"B, L, Nh, Dh = 4, 1024, 4, 128 # Not sure these shapes are realistic? Just a test.\n",
"Q = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
"K = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
"V = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
"\n",
"KK = K.swapaxes(1, 2)\n",
"QQ = Q.swapaxes(1, 2)\n",
"VV = V.swapaxes(1, 2)\n",
"\n",
"gt_O = torch.matmul((1./L) * torch.nn.functional.relu(torch.matmul(QQ / np.sqrt(Dh), KK.swapaxes(-2, -1))) * torch.ones(1, 1, L, L).cuda().tril().bfloat16(), VV).swapaxes(1, 2)\n",
"O = relu_attn(Q, K, V)\n",
"print((gt_O - O).abs().mean())"
],
"metadata": {
"id": "91ujF-Jd02A5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "301805c9-ffff-4f38-ab8f-988dcf9e3dd0"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(0., device='cuda:0', dtype=torch.bfloat16)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Speed check forward pass"
],
"metadata": {
"id": "tS7kBSyb1U-c"
}
},
{
"cell_type": "code",
"source": [
"%timeit torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))"
],
"metadata": {
"id": "gMZQY_1n1CRG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "71b2ac0d-b24a-4c27-f5ad-bef5453b853f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"171 µs ± 1.32 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"relu_attn(Q, K, V)\n",
"%timeit relu_attn(Q, K, V)"
],
"metadata": {
"id": "OU4_xNSa1g57",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f251da95-c8d3-4a4c-e453-a8847bf9b800"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"72.4 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"relu_attn(Q, K, V, is_squared=True)\n",
"%timeit relu_attn(Q, K, V, is_squared=True)"
],
"metadata": {
"id": "zORlZJdp1ji8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "43f06bf1-3b89-4859-c53a-bb88a9aeafe5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"73 µs ± 8.88 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Speed as a function of seqlen"
],
"metadata": {
"id": "C4ISt9NO09jC"
}
},
{
"cell_type": "code",
"source": [
"seqlens = [256, 512, 1024, 2048]\n",
"flash_attn_speed = []\n",
"relu_attn_speed = []\n",
"\n",
"for seqlen in seqlens:\n",
" B, L, Nh, Dh = 4, seqlen, 4, 128\n",
" Q = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
" K = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
" V = torch.randn(B, L, Nh, Dh).cuda().bfloat16()\n",
"\n",
" torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))\n",
" flash_attn_result = %timeit -o torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1./np.sqrt(Dh))\n",
" flash_attn_speed.append(flash_attn_result.average)\n",
"\n",
" relu_attn(Q, K, V)\n",
" relu_attn_result = %timeit -o relu_attn(Q, K, V)\n",
" relu_attn_speed.append(relu_attn_result.average)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oUEUpIHD0_lh",
"outputId": "a59e49f8-0d90-4312-ca74-5937f4015da1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"46.9 µs ± 259 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"55.5 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"87.3 µs ± 16.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"56.2 µs ± 490 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"171 µs ± 2.42 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"72.5 µs ± 61 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"342 µs ± 3.02 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"179 µs ± 37.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(seqlens, flash_attn_speed, label='Flash Attn v2', marker='s')\n",
"plt.plot(seqlens, relu_attn_speed, label='Relu Attn', marker='o')\n",
"plt.grid()\n",
"plt.xlabel('Seqlen')\n",
"plt.ylabel('Time')\n",
"plt.xscale('log')\n",
"plt.ylim(0)\n",
"plt.legend()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 472
},
"id": "0nwBTzWG1QPe",
"outputId": "18d1adbf-4fba-4e40-9bec-72929755ab22"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7de2d1105870>"
]
},
"metadata": {},
"execution_count": 32
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment