Skip to content

Instantly share code, notes, and snippets.

@machinaut
Created July 26, 2021 23:08
Show Gist options
  • Save machinaut/30b365d31abb4941fc838e0acb9e5db3 to your computer and use it in GitHub Desktop.
Save machinaut/30b365d31abb4941fc838e0acb9e5db3 to your computer and use it in GitHub Desktop.
Trying a bare CUDA vector add against PyTorch and Triton
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import os\n",
"from ctypes import CDLL, c_void_p\n",
"\n",
"import torch\n",
"import triton\n",
"import triton.language as tl"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"@triton.jit\n",
"def _add( X, Y, Z, N, **meta ):\n",
" pid = tl.program_id(0)\n",
" offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])\n",
" mask = offsets < N\n",
" x = tl.load(X + offsets, mask=mask)\n",
" y = tl.load(Y + offsets, mask=mask)\n",
" z = x + y\n",
" tl.store(Z + offsets, z)\n",
"\n",
"def add(x, y):\n",
" z = torch.empty_like(x)\n",
" N = z.shape[0]\n",
" grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n",
" _add[grid](x, y, z, N, BLOCK=1024)\n",
" return z"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"path = os.path.join(os.environ['PWD'], 'cuda_ctypes')\n",
"vadd_so = os.path.join(path, 'vadd.so')\n",
"vadd = CDLL(vadd_so)\n",
"\n",
"def add_cuda(x, y, threads=1024):\n",
" z = torch.empty_like(x)\n",
" N = z.shape[0]\n",
" # Get pointers to the data\n",
" xp = c_void_p(x.data_ptr())\n",
" yp = c_void_p(y.data_ptr())\n",
" zp = c_void_p(z.data_ptr())\n",
" # Run the cuda kernel\n",
" vadd.vadd(xp, yp, zp, N, threads)\n",
" return z\n",
"\n",
"torch.manual_seed(0)\n",
"size = 98432\n",
"x = torch.rand(size, device='cuda', dtype=torch.float32)\n",
"y = torch.rand(size, device='cuda', dtype=torch.float32)\n",
"za = x + y\n",
"zb = add_cuda(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between cuda and triton is ' f'{torch.max(torch.abs(za - zb))}')"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')\n",
"The maximum difference between cuda and triton is 0.0\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"@triton.testing.perf_report(\n",
" triton.testing.Benchmark(\n",
" x_names=['size'], # argument names to use as an x-axis for the plot\n",
" x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`\n",
" x_log=True, # x axis is logarithmic\n",
" line_arg='provider', # argument name whose value corresponds to a different line in the plot\n",
" line_vals=['triton', 'torch', 'cuda'], # possible values for `line_arg`\n",
" line_names=[\"Triton\", \"Torch\", \"CUDA\"], # label name for the lines\n",
" styles=[('blue', '-'), ('green', '-'), ('orange', '-')], # line styles\n",
" ylabel=\"GB/s\", # label name for the y-axis\n",
" plot_name=\"vector-add-performance\", # name for the plot. Used also as a file name for saving the plot.\n",
" args={} # values for function arguments not in `x_names` and `y_name`\n",
" )\n",
")\n",
"def benchmark(size, provider):\n",
" x = torch.rand(size, device='cuda', dtype=torch.float32)\n",
" y = torch.rand(size, device='cuda', dtype=torch.float32)\n",
" if provider == 'torch':\n",
" ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n",
" if provider == 'triton':\n",
" ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n",
" if provider == 'cuda':\n",
" ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_cuda(x, y))\n",
" gbps = lambda ms: 12 * size / ms * 1e-6\n",
" return gbps(ms), gbps(max_ms), gbps(min_ms)\n",
"\n",
"\n",
"benchmark.run(print_data=True, show_plots=True)"
],
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": ""
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"vector-add-performance:\n",
" size Triton Torch CUDA\n",
"0 4096.0 9.600000 9.600000 9.600000\n",
"1 8192.0 19.200000 19.200000 19.200000\n",
"2 16384.0 38.400001 38.400001 38.400001\n",
"3 32768.0 63.999998 63.999998 63.999998\n",
"4 65536.0 127.999995 127.999995 127.999995\n",
"5 131072.0 219.428568 219.428568 219.428568\n",
"6 262144.0 341.333321 341.333321 341.333321\n",
"7 524288.0 472.615390 472.615390 511.999982\n",
"8 1048576.0 614.400016 614.400016 614.400016\n",
"9 2097152.0 722.823517 722.823517 702.171410\n",
"10 4194304.0 780.190482 780.190482 768.000002\n",
"11 8388608.0 812.429770 812.429770 792.774204\n",
"12 16777216.0 833.084721 833.084721 812.429770\n",
"13 33554432.0 843.811163 842.004273 820.910214\n",
"14 67108864.0 848.362445 848.362445 824.352211\n",
"15 134217728.0 851.577704 850.656574 829.569620\n"
]
}
],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.9.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Builds vadd.so, the shared library that exposes the extern "C" vadd() launcher.
CUDA_PATH ?= /usr/local/cuda
# GPU architecture to compile for; override with e.g. `make ARCH=sm_80`.
ARCH ?= sm_70
.PHONY: clean
vadd.so: vadd.o
	nvcc -shared $^ -o $@ -lcuda
# -c means "compile only" and takes no argument: the source ($<) is the input and
# the object file is named explicitly with -o.
vadd.o: vadd.cu
	nvcc -I$(CUDA_PATH)/include -I$(CUDA_PATH)/samples/common/inc -arch=$(ARCH) --compiler-options '-fPIC' -c $< -o $@
clean:
	rm -f *.o *.so
// For the CUDA runtime API routines (prefixed with "cuda", e.g. cudaMalloc)
// #include <cuda.h>
#include <cuda_runtime.h>
namespace
{
// Element-wise C[i] = A[i] + B[i] for every i in [0, n); one thread per element.
__global__ void _vadd(const float *A, const float *B, float *C, int n)
{
    // blockDim.x * blockIdx.x is unsigned 32-bit arithmetic. Converting that directly to
    // int can wrap to a negative value for the last block's surplus threads when n is
    // close to INT_MAX, and a negative index would pass the `i < n` check. Use 64 bits.
    const long long i = static_cast<long long>(blockDim.x) * blockIdx.x + threadIdx.x;
    if (i < n)
    {
        C[i] = A[i] + B[i];
    }
}
}

// Launch _vadd over n elements using `threads` threads per block.
//
// A, B and C must be device pointers to at least n floats. The kernel is launched on
// the default stream and this function does not synchronize.
//
// Returns the cudaError_t of the launch as an int (0 == cudaSuccess). Callers that
// ignore the return value are unaffected. n <= 0 is a no-op returning cudaSuccess;
// threads <= 0 returns cudaErrorInvalidValue instead of dividing by zero.
extern "C" int vadd(const float *A, const float *B, float *C, int n, int threads)
{
    if (threads <= 0)
    {
        return static_cast<int>(cudaErrorInvalidValue);
    }
    if (n <= 0)
    {
        // A zero-block launch is an invalid configuration; there is nothing to do anyway.
        return static_cast<int>(cudaSuccess);
    }
    // 64-bit arithmetic: n + threads - 1 can overflow int for large n.
    const long long blocks = (static_cast<long long>(n) + threads - 1) / threads;
    _vadd<<<static_cast<unsigned int>(blocks), threads>>>(A, B, C, n);
    // Fetch (and reset) the last runtime error, which includes any launch-configuration
    // error from the launch above (e.g. threads larger than the device allows), instead
    // of leaving it pending for a later cudaGetLastError caller.
    return static_cast<int>(cudaGetLastError());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment