Skip to content

Instantly share code, notes, and snippets.

@grinisrit
Created April 28, 2021 12:52
Show Gist options
  • Save grinisrit/280e4f14b17fe5ee37e2e254700d9fd0 to your computer and use it in GitHub Desktop.
Save grinisrit/280e4f14b17fe5ee37e2e254700d9fd0 to your computer and use it in GitHub Desktop.
Using PyTorch with Numba and CUDA
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "numba_torch_cuda.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "5Mt7dgLwmnVJ"
},
"source": [
"!lscpu"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "bvobufFmg85f"
},
"source": [
"Based on a [gist](https://gist.github.com/t-vi/2f4fe23a5b473b9dceb95b163378b4d5#file-pytorch-numba-py) by [Thomas Viehmann](https://gist.github.com/t-vi)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xMMm5EHZusNL"
},
"source": [
"from numba import cuda, njit, prange\n",
"import numpy as np\n",
"import math\n",
"import torch\n",
"import ctypes\n",
"\n",
"@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')\n",
"def cu_exp_matrix_mul(A, c, d, u, v, b, n, m):\n",
"    '''CUDA kernel: v[bi, ni] = sum over mi of exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi].\n",
"\n",
"    Each thread produces one entry of v. The x axis of the launch grid runs\n",
"    over the first index of v (bounded by b) and the y axis over the second\n",
"    (bounded by n); m is the length of the summed axis.\n",
"    A is indexed as [ni, mi], c and u as [bi, mi], d and v as [bi, ni].\n",
"    '''\n",
"    # cuda.grid(2) yields threadIdx + blockIdx * blockDim for the x and y axes.\n",
"    bi, ni = cuda.grid(2)\n",
"\n",
"    # Threads that fall outside the (b, n) output have nothing to write.\n",
"    if bi < b and ni < n:\n",
"        acc = 0\n",
"        for k in range(m):\n",
"            acc += math.exp(-A[ni, k]+c[bi, k]+d[bi, ni]) * u[bi, k]\n",
"        v[bi, ni] = acc\n",
"\n",
"\n",
"@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')\n",
"def gnu_exp_matrix_mul(A, c, d, u, v, b, n, m):\n",
"    '''Numba-compiled CPU loop nest (no parallel=True).\n",
"\n",
"    Writes v[row, col] = sum over k of exp(-A[col, k] + c[row, k] + d[row, col]) * u[row, k]\n",
"    for every row < b and col < n, with k running over range(m).\n",
"    The result is stored into v in place; nothing is returned.\n",
"    '''\n",
"    for row in range(b):\n",
"        for col in range(n):\n",
"            acc = 0\n",
"            for k in range(m):\n",
"                exponent = -A[col, k]+c[row, k]+d[row, col]\n",
"                acc += math.exp(exponent) * u[row, k]\n",
"            v[row, col] = acc\n",
"\n",
"@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)', parallel=True)\n",
"def omp_exp_matrix_mul(A, c, d, u, v, b, n, m):\n",
"    '''Multi-threaded Numba version of the exp-weighted product.\n",
"\n",
"    Writes v[bi, ni] = sum over mi of exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]\n",
"    for every bi < b and ni < n. The result is stored into v in place.\n",
"\n",
"    Only the loop over bi is a prange: Numba parallelises the outermost prange\n",
"    and runs any prange nested inside it as an ordinary range, so the inner\n",
"    loops are written as plain range to show what actually executes.\n",
"    '''\n",
"    for bi in prange(b):\n",
"        for ni in range(n):\n",
"            r = 0\n",
"            for mi in range(m):\n",
"                r += math.exp(-A[ni, mi]+c[bi, mi]+d[bi, ni]) * u[bi, mi]\n",
"            v[bi, ni] = r\n",
"\n",
"def py_exp_matrix_mul(A, c, d, u, v, b, n, m):\n",
"    '''Interpreted reference implementation (no Numba).\n",
"\n",
"    Writes v[row, col] = sum over k of exp(-A[col, k] + c[row, k] + d[row, col]) * u[row, k]\n",
"    for every row < b and col < n, with k running over range(m).\n",
"    The result is stored into v in place; nothing is returned.\n",
"    '''\n",
"    for row in range(b):\n",
"        for col in range(n):\n",
"            total = 0\n",
"            for k in range(m):\n",
"                weight = math.exp(-A[col, k]+c[row, k]+d[row, col])\n",
"                total += weight * u[row, k]\n",
"            v[row, col] = total\n",
"\n",
"def get_devicendarray(t):\n",
"    '''Wrap a CUDA float32 torch tensor as a Numba DeviceNDArray without copying.\n",
"\n",
"    The returned array aliases the memory at t.data_ptr(), so t must stay alive\n",
"    while the array is in use. Shape and strides are taken from the tensor, so\n",
"    non-contiguous views are described correctly.\n",
"\n",
"    Raises AssertionError if t is not a torch.cuda.FloatTensor.\n",
"    '''\n",
"    assert t.type() == 'torch.cuda.FloatTensor'\n",
"    itemsize = t.element_size()\n",
"    shape = tuple(t.size())\n",
"    strides = t.stride()\n",
"    # Bytes from data_ptr() up to and including the last addressed element.\n",
"    # This is numel * itemsize for a contiguous tensor and more for strided views.\n",
"    if t.numel() > 0:\n",
"        extent = (1 + sum((size - 1) * stride for size, stride in zip(shape, strides))) * itemsize\n",
"    else:\n",
"        extent = 0\n",
"    ctx = cuda.cudadrv.devices.get_context(t.device.index)\n",
"    # Use a fixed 64-bit integer for the address: c_ulong is only 32 bits wide\n",
"    # on Windows and would truncate the device pointer there.\n",
"    mp = cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_uint64(t.data_ptr()), extent)\n",
"    return cuda.cudadrv.devicearray.DeviceNDArray(shape, [s * itemsize for s in strides], np.dtype('float32'),\n",
"                                                  gpu_data=mp, stream=torch.cuda.current_stream().cuda_stream)\n",
"\n",
"def batch_expmat_product(A, c, d, u, omp=False, py=False):\n",
"    '''Return v with v[bi, ni] = sum over mi of exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi].\n",
"\n",
"    Arguments are 2-D torch tensors: A is (n, m), c and u are (b, m), d is (b, n).\n",
"    The kernels are compiled for float32 data.\n",
"\n",
"    If all four inputs are CUDA tensors the CUDA kernel is used. Otherwise the\n",
"    computation runs on the CPU: omp=True selects the parallel Numba kernel,\n",
"    py=True the interpreted Python loops, and the default is the serial Numba\n",
"    kernel (omp wins if both flags are set).\n",
"\n",
"    Returns a (b, n) tensor with the dtype and device of u.\n",
"    Raises AssertionError on a rank or size mismatch.\n",
"    '''\n",
"    BLOCK = 32  # threads per block along each of the two grid axes\n",
"    # Check ranks before indexing sizes so a wrong rank fails with this message.\n",
"    assert A.dim()==2 and c.dim()==2 and d.dim()==2 and u.dim()==2, 'dimension mismatch'\n",
"    b = c.size(0)\n",
"    n, m = A.size()\n",
"    assert c.size(1)==m and d.size(0)==b and d.size(1)==n and u.size(0)==b and u.size(1)==m, 'size mismatch'\n",
"\n",
"    operands = (A, c, d, u)\n",
"    if all(x.is_cuda for x in operands):\n",
"        v = u.new_zeros(d.size())\n",
"        # A launch with a zero-sized grid is invalid; an empty output needs no work.\n",
"        if v.numel() > 0:\n",
"            Ad, cd, dd, ud, vd = (get_devicendarray(x) for x in operands + (v,))\n",
"            # Grid x covers the b axis and grid y the n axis, matching the kernel;\n",
"            # m is only the length of the summed axis.\n",
"            grid = ((b-1)//BLOCK+1, (n-1)//BLOCK+1)\n",
"            cu_exp_matrix_mul[grid, (BLOCK, BLOCK)](Ad, cd, dd, ud, vd, b, n, m)\n",
"        return v\n",
"\n",
"    # CPU path. The output buffer is allocated on the CPU so that the NumPy view\n",
"    # the kernel writes into shares memory with the tensor that is returned,\n",
"    # even when some of the inputs live on the GPU.\n",
"    v_host = torch.zeros(b, n, dtype=u.dtype)\n",
"    Ad, cd, dd, ud = (x.cpu().numpy() for x in operands)\n",
"    vd = v_host.numpy()\n",
"    if omp:\n",
"        kernel = omp_exp_matrix_mul\n",
"    elif py:\n",
"        kernel = py_exp_matrix_mul\n",
"    else:\n",
"        kernel = gnu_exp_matrix_mul\n",
"    kernel(Ad, cd, dd, ud, vd, b, n, m)\n",
"    # .to() is a no-op when u is already on the CPU.\n",
"    return v_host.to(u.device)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AtM6Z708u4lR"
},
"source": [
"# Seed the RNG so the inputs, and the error figures computed from them, are reproducible.\n",
"torch.manual_seed(0)\n",
"\n",
"b, n, m = 100, 200, 300\n",
"A = torch.randn(n, m)\n",
"c = torch.randn(b, m)\n",
"d = torch.randn(b, n)\n",
"u = torch.randn(b, m)\n",
"t = torch.randn(b, n)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ao8Rt64Dh-HG"
},
"source": [
"# Reference result from the interpreted Python loops.\n",
"w_py = batch_expmat_product(A, c, d, u, py=True)\n",
"w_py"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7kyedvaiif7c"
},
"source": [
"# Mean absolute deviation of the serial Numba kernel from the Python reference.\n",
"w_cpu = batch_expmat_product(A, c, d, u)\n",
"(w_py - w_cpu).abs().mean()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HH_7mCpBirQp"
},
"source": [
"# Mean absolute deviation of the parallel Numba kernel from the Python reference.\n",
"w_cpu_omp = batch_expmat_product(A, c, d, u, omp=True)\n",
"(w_py - w_cpu_omp).abs().mean()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oYXhz8QyvUap"
},
"source": [
"# Move the inputs to the GPU and compare the CUDA kernel with the Python reference.\n",
"Acu, ccu, dcu, ucu = [x.cuda() for x in (A, c, d, u)]\n",
"w_gpu = batch_expmat_product(Acu, ccu, dcu, ucu)\n",
"(w_py.cuda() - w_gpu).abs().mean()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PgR_pgnai8m7"
},
"source": [
"%timeit batch_expmat_product(A,c,d,u, py=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "b4nDpj0ejeDJ"
},
"source": [
"%timeit batch_expmat_product(A,c,d,u)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XEImcIkSjeUD"
},
"source": [
"%timeit batch_expmat_product(A,c,d,u, omp=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gBHvKyoUjgoP"
},
"source": [
"# CUDA kernel launches return before the kernel has finished, so wait for the\n",
"# device inside the timed statement; otherwise mostly launch overhead is measured.\n",
"%timeit batch_expmat_product(Acu,ccu,dcu,ucu); torch.cuda.synchronize()"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment