chao-ji/7b9a1a1eeab972302eef8d55f2a71e18
Learning Triton: Strides and Offsets.ipynb
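The offset arithmetic in the notebook's first markdown cell can be checked without a GPU. Below is a minimal pure-Python sketch (no NumPy, PyTorch, or Triton needed). The helper `offsets_2d` is hypothetical and stands in for the broadcast expression `arange(0, size0)[:, None] * stride0 + arange(0, size1) * stride1`. It builds the offset grid of a row-major `[2, 3]` tensor, then the grid obtained by swapping the axes together with their strides, and finally gathers a flat buffer through the transposed grid.

```python
def offsets_2d(size0, size1, stride0, stride1):
    """Hypothetical stand-in for arange(size0)[:, None] * stride0 + arange(size1) * stride1."""
    return [[i * stride0 + j * stride1 for j in range(size1)] for i in range(size0)]


# A contiguous row-major [2, 3] tensor has strides (3, 1).
row_major = offsets_2d(2, 3, 3, 1)
print(row_major)  # [[0, 1, 2], [3, 4, 5]]

# Transposing swaps the sizes together with their strides: sizes (3, 2), strides (1, 3).
transposed = offsets_2d(3, 2, 1, 3)
print(transposed)  # [[0, 3], [1, 4], [2, 5]]

# Gathering the flat buffer through the transposed offsets yields the transposed matrix.
flat = [10, 11, 12, 13, 14, 15]  # row-major storage of [[10, 11, 12], [13, 14, 15]]
print([[flat[o] for o in row] for row in transposed])  # [[10, 13], [11, 14], [12, 15]]
```

Only the sizes and strides differ between the two calls; the flat buffer itself is never reordered, which is why a transpose can be expressed purely as a different set of offsets.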
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyMHdLXnRwnK8PcYblJNVI9w",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chao-ji/7b9a1a1eeab972302eef8d55f2a71e18/learning-triton-strides-and-offsets.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"**Transposing an N-dimensional tensor**\n",
"\n",
"In this short Colab notebook, I explain how to compute the memory offsets of the elements of a multi-dimensional array from the axis **strides** and indices.\n",
"\n",
"Suppose the elements of a 2d tensor `a` with shape `[2, 3]` are stored contiguously in row-major order (the usual case, and PyTorch's default). Then their memory addresses are laid out as follows:\n",
"\n",
"```\n",
"2d indexing: a[0, 0]    a[0, 1]    a[0, 2]    a[1, 0]    a[1, 1]    a[1, 2]\n",
"pointers:    a_ptr      a_ptr + 1  a_ptr + 2  a_ptr + 3  a_ptr + 4  a_ptr + 5\n",
"\n",
"where a_ptr points to the first element of a, i.e. a[0, 0]\n",
"```\n",
"\n",
"Note that the memory stride of the 0th axis (i.e. height) is 3: you need to increment `a_ptr` by 3 positions to go from `a[0, 0]` to `a[1, 0]`. Likewise, the stride of the 1st axis (i.e. width) is 1.\n",
"\n",
"The memory offsets of all elements, arranged in the 2d layout, can then be computed by scaling the indices of the 0th dimension by its stride and adding the indices of the 1st dimension, scaled by its own stride (which is 1). Broadcasting the column vector (`[:, None]`) against the row vector produces the full grid of offsets:\n",
"\n",
"```\n",
" arange(0, 2)[:, None] * 3 + arange(0, 3) * 1 =\n",
" [\n",
" [0, 1, 2],\n",
" [3, 4, 5],\n",
" ]\n",
"```\n",
"\n",
"To transpose, we simply make the width the 0th dimension and the height the 1st dimension. Each axis keeps its own stride, so the strides DO NOT change; they are only swapped along with their axes:\n",
"\n",
"```\n",
" arange(0, 3)[:, None] * 1 + arange(0, 2) * 3 =\n",
" [\n",
" [0, 3],\n",
" [1, 4],\n",
" [2, 5],\n",
" ]\n",
"```"
],
"metadata": {
"id": "u7JXiTB4u-yk"
}
},
{
"cell_type": "markdown",
"source": [
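"Below are simple kernels written in [Triton](https://triton-lang.org/main/index.html) (a Python-like language and compiler for GPU programming) for transposing 2d and 3d tensors. Each kernel is launched as a single program instance that processes the whole tensor, so every dimension size must be a power of 2, as required by `tl.arange`."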
],
"metadata": {
"id": "XucSJBhQ6Lu0"
}
},
{
"cell_type": "code",
"source": [
"import triton\n",
"import triton.language as tl\n",
"import torch"
],
"metadata": {
"id": "yFM5ACGsSUq0"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
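"@triton.jit\n",
"def transpose_2d(IN, OUT,\n",
"                 in_size0: tl.constexpr, in_size1: tl.constexpr,\n",
"                 in_stride0: tl.constexpr, in_stride1: tl.constexpr,\n",
"                 out_size0: tl.constexpr, out_size1: tl.constexpr,\n",
"                 out_stride0: tl.constexpr, out_stride1: tl.constexpr,\n",
"                 ):\n",
"    # The whole tensor is handled by a single program, so the program id is unused.\n",
"    pid = tl.program_id(0)\n",
"\n",
"    # The caller passes the input's sizes and strides with the two axes swapped,\n",
"    # so in_offs[i, j] is the memory offset of IN[j, i].\n",
"    in_offs = tl.arange(0, in_size0)[:, None] * in_stride0 + tl.arange(0, in_size1) * in_stride1\n",
"    # out_offs[i, j] is the memory offset of OUT[i, j] (contiguous, row-major).\n",
"    out_offs = tl.arange(0, out_size0)[:, None] * out_stride0 + tl.arange(0, out_size1) * out_stride1\n",
"    # Gather the input in transposed order and write it out contiguously.\n",
"    data = tl.load(IN + in_offs)\n",
"    tl.store(OUT + out_offs, data)"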
],
"metadata": {
"id": "Mpr1WqJhBSvo"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"a = torch.randn(2, 4).cuda()\n",
"out = torch.empty(4, 2).cuda()\n",
"\n",
"grid = (1,)\n",
"transpose_2d[grid](a, out, *a.shape[::-1], *a.stride()[::-1], *out.shape, *out.stride())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Q_cdwzP1B1Lo",
"outputId": "1dff7ff6-e021-4f84-d081-bf2d85541a72"
},
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<triton.compiler.compiler.CompiledKernel at 0x7a7d7aa27040>"
]
},
"metadata": {},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"source": [
"a"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p8T7mt4JERrd",
"outputId": "f5608dfe-7e5f-443c-ee52-642fda37255f"
},
"execution_count": 33,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 1.2764, -1.4149, 1.9301, -0.5038],\n",
" [-1.4325, -1.0734, 0.2115, -0.0343]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 33
}
]
},
{
"cell_type": "code",
"source": [
"out"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "173qMDUpEUkW",
"outputId": "d9335e02-77bc-491b-aeb8-6d368be6af16"
},
"execution_count": 34,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 1.2764, -1.4325],\n",
" [-1.4149, -1.0734],\n",
" [ 1.9301, 0.2115],\n",
" [-0.5038, -0.0343]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"source": [
"assert (a.t() == out).all()"
],
"metadata": {
"id": "aRLpF7csu8_Z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
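"@triton.jit\n",
"def transpose_3d(IN, OUT,\n",
"                 in_size0: tl.constexpr, in_size1: tl.constexpr, in_size2: tl.constexpr,\n",
"                 in_stride0: tl.constexpr, in_stride1: tl.constexpr, in_stride2: tl.constexpr,\n",
"                 out_size0: tl.constexpr, out_size1: tl.constexpr, out_size2: tl.constexpr,\n",
"                 out_stride0: tl.constexpr, out_stride1: tl.constexpr, out_stride2: tl.constexpr,\n",
"                 ):\n",
"    # The whole tensor is handled by a single program, so the program id is unused.\n",
"    pid = tl.program_id(0)\n",
"\n",
"    # The caller passes the input's sizes and strides reordered by the permutation,\n",
"    # so in_offs[i, j, k] is the memory offset of the input element that lands at OUT[i, j, k].\n",
"    in_offs = (tl.arange(0, in_size0)[:, None, None] * in_stride0\n",
"               + tl.arange(0, in_size1)[:, None] * in_stride1\n",
"               + tl.arange(0, in_size2) * in_stride2)\n",
"    # out_offs[i, j, k] is the memory offset of OUT[i, j, k] (contiguous, row-major).\n",
"    out_offs = (tl.arange(0, out_size0)[:, None, None] * out_stride0\n",
"                + tl.arange(0, out_size1)[:, None] * out_stride1\n",
"                + tl.arange(0, out_size2) * out_stride2)\n",
"    # Gather the input in permuted order and write it out contiguously.\n",
"    data = tl.load(IN + in_offs)\n",
"    tl.store(OUT + out_offs, data)"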
],
"metadata": {
"id": "rnuBFJcMEVnU"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"source": [
"a = torch.randn(2, 4, 8).cuda()\n",
"out = torch.empty(8, 2, 4).cuda()\n",
"\n",
"perm_indices = [2, 0, 1]\n",
"\n",
"perm_shape = [a.shape[i] for i in perm_indices]\n",
"perm_stride = [a.stride()[i] for i in perm_indices]\n",
"\n",
"grid = (1,)\n",
"transpose_3d[grid](a, out, *perm_shape, *perm_stride, *out.shape, *out.stride())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0c30vMCGsCBi",
"outputId": "ab096a2e-10e5-480c-e1c2-203b348d264f"
},
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<triton.compiler.compiler.CompiledKernel at 0x7a7c64237e20>"
]
},
"metadata": {},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"source": [
"assert (a.permute(*perm_indices) == out).all()"
],
"metadata": {
"id": "9e5sWBqLspSa"
},
"execution_count": 38,
"outputs": []
}
]
}
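The host code for both launches in the notebook follows one recipe: reorder the input's shape and strides by a permutation (`[1, 0]` in the 2d case, which is what `a.shape[::-1]` and `a.stride()[::-1]` compute, and `perm_indices = [2, 0, 1]` in the 3d case), then write the output contiguously. The pure-Python sketch below emulates that indexing on a flat row-major buffer for any rank. `row_major_strides` and `permute_flat` are hypothetical helpers, not Triton or PyTorch APIs, and the sketch mirrors only the offset computation, not GPU execution.

```python
from itertools import product


def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major tensor with the given shape."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def permute_flat(flat, shape, perm):
    """Hypothetical emulation of transpose_2d / transpose_3d on a flat row-major buffer.

    The input is read with its sizes and strides reordered by `perm` (the
    in_size*/in_stride* kernel arguments) and the result is produced in
    row-major order (the contiguous out_size*/out_stride* arguments).
    """
    strides = row_major_strides(shape)
    perm_shape = [shape[p] for p in perm]
    perm_stride = [strides[p] for p in perm]
    out = []
    # product() varies the last index fastest, i.e. it walks the output in row-major order.
    for idx in product(*(range(n) for n in perm_shape)):
        out.append(flat[sum(i * s for i, s in zip(idx, perm_stride))])
    return out, perm_shape


# 3d case: a[x, y, z] == 12 * x + 4 * y + z for a contiguous [2, 3, 4] tensor.
out, out_shape = permute_flat(list(range(24)), [2, 3, 4], [2, 0, 1])
print(out_shape)  # [4, 2, 3]
print(out[:6])  # [0, 4, 8, 12, 16, 20]

# 2d case: transposing [[0, 1, 2, 3], [4, 5, 6, 7]].
out2, shape2 = permute_flat(list(range(8)), [2, 4], [1, 0])
print(shape2, out2)  # [4, 2] [0, 4, 1, 5, 2, 6, 3, 7]
```

`row_major_strides([2, 4, 8])` gives `[32, 8, 1]`, the same values as `torch.randn(2, 4, 8).stride()`. Unlike this emulation, the single-program Triton kernels additionally require every size to be a power of 2 because of `tl.arange`.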