-
-
Save grinisrit/280e4f14b17fe5ee37e2e254700d9fd0 to your computer and use it in GitHub Desktop.
Using PyTorch with Numba and CUDA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "numba_torch_cuda.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5Mt7dgLwmnVJ" | |
}, | |
"source": [ | |
"!lscpu" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "bvobufFmg85f" | |
}, | |
"source": [ | |
"Based on a [gist](https://gist.github.com/t-vi/2f4fe23a5b473b9dceb95b163378b4d5#file-pytorch-numba-py) by [Thomas Viehmann](https://gist.github.com/t-vi)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xMMm5EHZusNL" | |
}, | |
"source": [ | |
"from numba import cuda, njit, prange\n", | |
"import numpy as np\n", | |
"import math\n", | |
"import torch\n", | |
"import ctypes\n", | |
"\n", | |
"@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')\n", | |
"def cu_exp_matrix_mul(A, c, d, u, v, b, n, m):\n", | |
" tx = cuda.threadIdx.x\n", | |
" ty = cuda.threadIdx.y\n", | |
" bx = cuda.blockIdx.x\n", | |
" by = cuda.blockIdx.y\n", | |
" bw = cuda.blockDim.x\n", | |
" bh = cuda.blockDim.y\n", | |
"\n", | |
" bi = tx + bx * bw\n", | |
" ni = ty + by * bh\n", | |
"\n", | |
" if ni >= n or bi >= b:\n", | |
" return\n", | |
" r = 0\n", | |
" for mi in range(m):\n", | |
" r += math.exp(-A[ni, mi]+c[bi,mi]+d[bi,ni]) * u[bi, mi]\n", | |
" v[bi, ni] = r\n", | |
"\n", | |
"\n", | |
"@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')\n", | |
"def gnu_exp_matrix_mul(A, c, d, u, v, b, n, m):\n", | |
" for bi in range(b):\n", | |
" for ni in range(n):\n", | |
" r = 0\n", | |
" for mi in range(m):\n", | |
" r += math.exp(-A[ni, mi]+c[bi,mi]+d[bi,ni]) * u[bi, mi]\n", | |
" v[bi, ni] = r \n", | |
"\n", | |
"@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)', parallel=True)\n", | |
"def omp_exp_matrix_mul(A, c, d, u, v, b, n, m):\n", | |
" for bi in prange(b):\n", | |
" for ni in prange(n):\n", | |
" r = 0\n", | |
" for mi in range(m):\n", | |
" r += math.exp(-A[ni, mi]+c[bi,mi]+d[bi,ni]) * u[bi, mi]\n", | |
" v[bi, ni] = r \n", | |
"\n", | |
"def py_exp_matrix_mul(A, c, d, u, v, b, n, m):\n", | |
" for bi in range(b):\n", | |
" for ni in range(n):\n", | |
" r = 0\n", | |
" for mi in range(m):\n", | |
" r += math.exp(-A[ni, mi]+c[bi,mi]+d[bi,ni]) * u[bi, mi]\n", | |
" v[bi, ni] = r \n", | |
"\n", | |
"def get_devicendarray(t):\n", | |
" assert t.type() == 'torch.cuda.FloatTensor'\n", | |
" ctx = cuda.cudadrv.devices.get_context(t.device.index)\n", | |
" mp = cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel()*4)\n", | |
" return cuda.cudadrv.devicearray.DeviceNDArray(t.size(), [i*4 for i in t.stride()], np.dtype('float32'), \n", | |
" gpu_data=mp, stream=torch.cuda.current_stream().cuda_stream)\n", | |
"\n", | |
"def batch_expmat_product(A, c, d, u, omp=False, py=False):\n", | |
" BLOCK=32\n", | |
" b = c.size(0)\n", | |
" n = A.size(0)\n", | |
" m = A.size(1)\n", | |
" assert A.dim()==2 and c.dim()==2 and d.dim()==2 and u.dim()==2, \"dimension mismatch\"\n", | |
" assert c.size(1)==m and d.size(0)==b and d.size(1)==n and u.size(0)==b and u.size(1)==m, \"size mismatch\"\n", | |
" v = u.new(d.size()).zero_()\n", | |
"\n", | |
" if A.is_cuda and c.is_cuda and d.is_cuda and u.is_cuda:\n", | |
" Ad,cd,dd,ud,vd = (get_devicendarray(x) for x in (A,c,d,u,v))\n", | |
" cu_exp_matrix_mul[((b-1)//BLOCK+1,(m-1)//BLOCK+1),(BLOCK,BLOCK)](Ad,cd,dd,ud,vd,b,n,m)\n", | |
" else:\n", | |
" Ad,cd,dd,ud,vd = (x.cpu().numpy() for x in (A,c,d,u,v))\n", | |
" if omp:\n", | |
" omp_exp_matrix_mul(Ad,cd,dd,ud,vd,b,n,m)\n", | |
" else:\n", | |
" if py:\n", | |
" py_exp_matrix_mul(Ad,cd,dd,ud,vd,b,n,m)\n", | |
" else:\n", | |
" gnu_exp_matrix_mul(Ad,cd,dd,ud,vd,b,n,m)\n", | |
" return v\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AtM6Z708u4lR" | |
}, | |
"source": [ | |
"b,n,m = 100,200,300\n", | |
"A = torch.randn(n,m)\n", | |
"c = torch.randn(b,m)\n", | |
"d = torch.randn(b,n)\n", | |
"u = torch.randn(b,m)\n", | |
"t = torch.randn(b,n)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ao8Rt64Dh-HG" | |
}, | |
"source": [ | |
"w_py = batch_expmat_product(A,c,d,u, py=True)\n", | |
"w_py" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7kyedvaiif7c" | |
}, | |
"source": [ | |
"w_cpu = batch_expmat_product(A,c,d,u)\n", | |
"torch.abs(w_py - w_cpu).mean()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HH_7mCpBirQp" | |
}, | |
"source": [ | |
"w_cpu_omp = batch_expmat_product(A,c,d,u, omp=True)\n", | |
"torch.abs(w_py - w_cpu_omp).mean()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "oYXhz8QyvUap" | |
}, | |
"source": [ | |
"Acu,ccu,dcu,ucu = (x.cuda() for x in (A,c,d,u))\n", | |
"w_gpu = batch_expmat_product(Acu,ccu,dcu,ucu)\n", | |
"torch.abs(w_py.cuda() - w_gpu).mean()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PgR_pgnai8m7" | |
}, | |
"source": [ | |
"%timeit batch_expmat_product(A,c,d,u, py=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b4nDpj0ejeDJ" | |
}, | |
"source": [ | |
"%timeit batch_expmat_product(A,c,d,u)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XEImcIkSjeUD" | |
}, | |
"source": [ | |
"%timeit batch_expmat_product(A,c,d,u, omp=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gBHvKyoUjgoP" | |
}, | |
"source": [ | |
"%timeit batch_expmat_product(Acu,ccu,dcu,ucu)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment