@dionhaefner
Created March 23, 2020 14:19
Benchmarks of the tridiagonal matrix algorithm in Python
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tridiagonal matrix solver benchmarks"
]
},
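{
"cell_type": "markdown",
"metadata": {},
"source": [
"All implementations below are variants of the Thomas algorithm, which solves a tridiagonal system $a_i x_{i-1} + b_i x_i + c_i x_{i+1} = d_i$ in $O(n)$ operations. A forward sweep computes modified coefficients,\n",
"\n",
"$$c'_i = \\frac{c_i}{b_i - a_i c'_{i-1}}, \\qquad d'_i = \\frac{d_i - a_i d'_{i-1}}{b_i - a_i c'_{i-1}},$$\n",
"\n",
"followed by back substitution, $x_i = d'_i - c'_i x_{i+1}$. Here, many independent systems are solved at once, one for each position along the first two axes."
]
},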
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: XLA_PYTHON_CLIENT_PREALLOCATE=false\n",
"env: OMP_NUM_THREADS=1\n",
"env: CUDA_VISIBLE_DEVICES=0\n"
]
}
],
"source": [
"%env XLA_PYTHON_CLIENT_PREALLOCATE=false\n",
"%env OMP_NUM_THREADS=1\n",
"%env CUDA_VISIBLE_DEVICES=0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon Mar 23 15:02:13 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:03:00.0 Off | 0 |\n",
"| N/A 24C P0 29W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 1 Tesla P100-PCIE... Off | 00000000:04:00.0 Off | 0 |\n",
"| N/A 25C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"import numba as nb\n",
"import numba.cuda\n",
"from scipy.linalg import lapack"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"jax.config.update('jax_enable_x64', True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"shape = (360, 160, 115)"
]
},
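{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this shape, each call solves 360 × 160 = 57,600 independent tridiagonal systems of size 115 along the last axis."
]
},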
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implement TDMA\n",
"\n",
"#### NumPy"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def tdma_naive(a, b, c, d):\n",
" \"\"\"\n",
" Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.\n",
" \"\"\"\n",
" assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape\n",
"\n",
" n = a.shape[-1]\n",
"\n",
" for i in range(1, n):\n",
" w = a[..., i] / b[..., i - 1]\n",
" b[..., i] += -w * c[..., i - 1]\n",
" d[..., i] += -w * d[..., i - 1]\n",
"\n",
" out = np.empty_like(a)\n",
" out[..., -1] = d[..., -1] / b[..., -1]\n",
"\n",
" for i in range(n - 2, -1, -1):\n",
" out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]\n",
"\n",
" return out"
]
},
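{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sanity check of the naive solver against a dense solve on a single small, diagonally dominant system (copies are passed because `tdma_naive` overwrites `b` and `d`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_n = 8\n",
"_a, _b, _c, _d = np.random.randn(4, _n)\n",
"_b += 10  # ensure diagonal dominance so the sweep cannot break down\n",
"# Assemble the equivalent dense matrix (a[0] and c[-1] are unused by the solver)\n",
"_M = np.diag(_b) + np.diag(_a[1:], -1) + np.diag(_c[:-1], 1)\n",
"_x = tdma_naive(_a.copy(), _b.copy(), _c.copy(), _d.copy())\n",
"np.testing.assert_allclose(_x, np.linalg.solve(_M, _d))"
]
},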
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Lapack"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def tdma_lapack(a, b, c, d):\n",
" a[..., 0] = c[..., -1] = 0 # remove couplings between slices\n",
" return lapack.dgtsv(a.flatten()[1:], b.flatten(), c.flatten()[:-1], d.flatten())[3].reshape(a.shape)"
]
},
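{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lapack.dgtsv` takes the subdiagonal (length $n-1$), diagonal (length $n$), superdiagonal (length $n-1$) and right-hand side of a single system; the solution is the fourth element of the returned tuple, hence the `[3]`."
]
},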
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Numba CPU"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@nb.guvectorize([(nb.float64[:],) * 5], '(n), (n), (n), (n) -> (n)', nopython=True)\n",
"def tdma_numba(a, b, c, d, out):\n",
" assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape\n",
"\n",
" n = a.shape[0]\n",
"\n",
" for i in range(1, n):\n",
" w = a[i] / b[i - 1]\n",
" b[i] += -w * c[i - 1]\n",
" d[i] += -w * d[i - 1]\n",
"\n",
" out[-1] = d[-1] / b[-1]\n",
"\n",
" for i in range(n - 2, -1, -1):\n",
" out[i] = (d[i] - c[i] * out[i + 1]) / b[i]"
]
},
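{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `guvectorize` signature `(n), (n), (n), (n) -> (n)` turns the one-dimensional kernel into a generalized ufunc that loops over all leading dimensions, so the three-dimensional inputs are handled automatically. Like the NumPy version, it modifies `b` and `d` in place."
]
},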
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Numba CUDA"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"nconst = shape[-1]\n",
"\n",
"\n",
"@nb.cuda.jit()\n",
"def tdma_numba_cuda_kernel(a, b, c, d, out):\n",
" i, j = nb.cuda.grid(2)\n",
" \n",
" if not(i < a.shape[0] and j < a.shape[1]):\n",
" return\n",
"\n",
" n = a.shape[2]\n",
" \n",
" cp = nb.cuda.local.array((nconst,), dtype=nb.float64)\n",
" dp = nb.cuda.local.array((nconst,), dtype=nb.float64)\n",
" \n",
" cp[0] = c[i, j, 0] / b[i, j, 0]\n",
" dp[0] = d[i, j, 0] / b[i, j, 0]\n",
" \n",
" for k in range(1, n):\n",
" norm_factor = b[i, j, k] - a[i, j, k] * cp[k-1]\n",
" cp[k] = c[i, j, k] / norm_factor\n",
" dp[k] = (d[i, j, k] - a[i, j, k] * dp[k-1]) / norm_factor\n",
"\n",
" out[i, j, n-1] = dp[n-1]\n",
"\n",
" for k in range(n - 2, -1, -1):\n",
" out[i, j, k] = dp[k] - cp[k] * out[i, j, k+1]\n",
" \n",
"\n",
"def tdma_numba_cuda(a, b, c, d):\n",
" assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape\n",
"\n",
" threadsperblock = (16, 16)\n",
" blockspergrid_x = math.ceil(a.shape[0] / threadsperblock[0])\n",
" blockspergrid_y = math.ceil(a.shape[1] / threadsperblock[1])\n",
" blockspergrid = (blockspergrid_x, blockspergrid_y)\n",
"\n",
" out = nb.cuda.device_array(a.shape, dtype=a.dtype)\n",
" tdma_numba_cuda_kernel[blockspergrid, threadsperblock](a, b, c, d, out)\n",
" return out"
]
},
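{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each CUDA thread solves one complete system. The scratch arrays `cp` and `dp` are allocated in local memory, which is why their size (`nconst`) has to be baked in as a compile-time constant rather than read from the array shapes."
]
},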
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### JAX"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax.lax\n",
"\n",
"\n",
"def tdma_jax_kernel(a, b, c, d):\n",
" def compute_primes(last_primes, x):\n",
" last_cp, last_dp = last_primes\n",
" a, b, c, d = x\n",
"\n",
" denom = 1. / (b - a * last_cp)\n",
" cp = c * denom\n",
" dp = (d - a * last_dp) * denom\n",
"\n",
" new_primes = (cp, dp)\n",
" return new_primes, new_primes\n",
"\n",
" diags = (a.T, b.T, c.T, d.T)\n",
" init = jnp.zeros((a.shape[1], a.shape[0]))\n",
" _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)\n",
"\n",
" def backsubstitution(last_x, x):\n",
" cp, dp = x\n",
" new_x = dp - cp * last_x\n",
" return new_x, new_x\n",
"\n",
" _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))\n",
"\n",
" return sol[::-1].T\n",
"\n",
"\n",
"tdma_jax = jax.jit(tdma_jax_kernel, backend='cpu')\n",
"tdma_jax_cuda = jax.jit(tdma_jax_kernel, backend='gpu')"
]
},
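{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jax.lax.scan` iterates over the leading axis of its inputs, so the arrays are transposed to put the system dimension first; the first scan is the forward sweep and the second, over the reversed coefficients, is the back substitution. The transposes are revisited at the end of the notebook."
]
},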
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### CuPy"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import cupy"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from string import Template\n",
"\n",
"kernel = Template('''\n",
"extern \"C\" __global__\n",
"void execute(\n",
" const ${DTYPE} *a,\n",
" const ${DTYPE} *b,\n",
" const ${DTYPE} *c,\n",
" const ${DTYPE} *d,\n",
" ${DTYPE} *solution\n",
"){\n",
" const size_t m = ${SYS_DEPTH};\n",
" const size_t total_size = ${SIZE};\n",
" const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * m;\n",
"\n",
" if (idx >= total_size) {\n",
" return;\n",
" }\n",
"\n",
" ${DTYPE} cp[${SYS_DEPTH}];\n",
" ${DTYPE} dp[${SYS_DEPTH}];\n",
"\n",
" cp[0] = c[idx] / b[idx];\n",
" dp[0] = d[idx] / b[idx];\n",
"\n",
" for (ptrdiff_t j = 1; j < m; ++j) {\n",
" const ${DTYPE} norm_factor = b[idx+j] - a[idx+j] * cp[j-1];\n",
" cp[j] = c[idx+j] / norm_factor;\n",
" dp[j] = (d[idx+j] - a[idx+j] * dp[j-1]) / norm_factor;\n",
" }\n",
"\n",
" solution[idx + m-1] = dp[m-1];\n",
" for (ptrdiff_t j=m-2; j >= 0; --j) {\n",
" solution[idx + j] = dp[j] - cp[j] * solution[idx + j+1];\n",
" }\n",
"}\n",
"''').substitute(\n",
" DTYPE='double',\n",
" SYS_DEPTH=shape[-1],\n",
" SIZE=np.product(shape)\n",
")\n",
"\n",
"tdma_cupy_kernel = cupy.RawKernel(kernel, 'execute')\n",
"\n",
"\n",
"def tdma_cupy(a, b, c, d, blocksize=256):\n",
" assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape\n",
"\n",
" a, b, c, d = (cupy.asarray(k) for k in (a, b, c, d))\n",
" out = cupy.empty(a.shape, dtype=a.dtype)\n",
" \n",
" tdma_cupy_kernel(\n",
" (math.ceil(a.size / a.shape[-1] / blocksize),),\n",
" (blocksize,),\n",
" (a, b, c, d, out)\n",
" )\n",
" \n",
" return out"
]
},
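{
"cell_type": "markdown",
"metadata": {},
"source": [
"As in the Numba CUDA version, each thread walks one contiguous system along the last axis. Note that this means neighboring threads access global memory with a stride of `m`, so loads are not coalesced; a transposed memory layout could likely improve this, at the cost of extra copies."
]
},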
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check results"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✔️\n",
"✔️\n",
"✔️\n",
"✔️\n",
"✔️\n",
"✔️\n"
]
}
],
"source": [
"np.random.seed(17)\n",
"a, b, c, d = np.random.randn(4, *shape)\n",
"res_naive = tdma_naive(a, b, c, d)\n",
"\n",
"for imp in (tdma_cupy, tdma_lapack, tdma_numba, tdma_numba_cuda, tdma_jax, tdma_jax_cuda):\n",
" np.random.seed(17)\n",
" a, b, c, d = np.random.randn(4, *shape)\n",
" out = imp(a, b, c, d)\n",
" \n",
" try:\n",
" out = out.get()\n",
" except AttributeError:\n",
" pass\n",
"\n",
" np.testing.assert_allclose(out, res_naive)\n",
" print('✔️')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(17)\n",
"a, b, c, d = np.random.randn(4, *shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CPU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### NumPy"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"403 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_naive(a, b, c, d)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Lapack"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"317 ms ± 7.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_lapack(a, b, c, d)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Numba"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"130 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_numba(a, b, c, d)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### JAX"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"325 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_jax(a, b, c, d).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GPU"
]
},
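{
"cell_type": "markdown",
"metadata": {},
"source": [
"GPU kernels launch asynchronously, so every timed cell below ends with an explicit synchronization. Inputs are copied to the device beforehand and an untimed warm-up call triggers JIT compilation, so neither transfers nor compilation are included in the timings."
]
},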
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Numba"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"ac, bc, cc, dc = (nb.cuda.to_device(k) for k in (a, b, c, d))\n",
"tdma_numba_cuda(ac, bc, cc, dc); # trigger compilation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13.2 ms ± 67.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_numba_cuda(ac, bc, cc, dc)\n",
"numba.cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### JAX"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"aj, bj, cj, dj = (jnp.array(k).block_until_ready() for k in (a, b, c, d))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14 ms ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_jax_cuda(aj, bj, cj, dj).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### CuPy"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"stream = cupy.cuda.stream.Stream()\n",
"\n",
"with stream:\n",
" ac, bc, cc, dc = (cupy.asarray(k) for k in (a, b, c, d))\n",
" tdma_cupy(ac, bc, cc, dc); # trigger compilation\n",
"\n",
"stream.synchronize()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.06 ms ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"with stream:\n",
" tdma_cupy(ac, bc, cc, dc)\n",
"\n",
"stream.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Try Jax without transposes"
]
},
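{
"cell_type": "markdown",
"metadata": {},
"source": [
"The variant below expects inputs that already have the system dimension first, avoiding the transposes inside the jitted function."
]
},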
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def tdma_jax_kernel_notrans(a, b, c, d):\n",
" def compute_primes(last_primes, x):\n",
" last_cp, last_dp = last_primes\n",
" a, b, c, d = x\n",
"\n",
" denom = 1. / (b - a * last_cp)\n",
" cp = c * denom\n",
" dp = (d - a * last_dp) * denom\n",
"\n",
" new_primes = (cp, dp)\n",
" return new_primes, new_primes\n",
"\n",
" diags = (a, b, c, d)\n",
" init = jnp.zeros((a.shape[1], a.shape[2]))\n",
" _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)\n",
"\n",
" def backsubstitution(last_x, x):\n",
" cp, dp = x\n",
" new_x = dp - cp * last_x\n",
" return new_x, new_x\n",
"\n",
" _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))\n",
"\n",
" return sol[::-1]\n",
"\n",
"\n",
"tdma_jax_notrans = jax.jit(tdma_jax_kernel_notrans, backend='cpu')\n",
"tdma_jax_cuda_notrans = jax.jit(tdma_jax_kernel_notrans, backend='gpu')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"at, bt, ct, dt = (k.T for k in (a, b, c, d))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"537 ms ± 3.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_jax_notrans(at, bt, ct, dt).block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"aj, bj, cj, dj = (jnp.array(k.T).block_until_ready() for k in (a, b, c, d))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.1 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"tdma_jax_cuda_notrans(aj, bj, cj, dj).block_until_ready()"
]
},
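{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary of measured timings\n",
"\n",
"| Implementation | Device | Time per solve |\n",
"|---|---|---|\n",
"| NumPy | CPU | 403 ms |\n",
"| LAPACK (`dgtsv`) | CPU | 317 ms |\n",
"| Numba (`guvectorize`) | CPU | 130 ms |\n",
"| JAX | CPU | 325 ms |\n",
"| JAX (no transposes) | CPU | 537 ms |\n",
"| Numba CUDA | GPU | 13.2 ms |\n",
"| JAX | GPU | 14 ms |\n",
"| JAX (no transposes) | GPU | 10.1 ms |\n",
"| CuPy (raw kernel) | GPU | 5.06 ms |"
]
},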
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}