Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created July 18, 2020 18:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/6826d02949e4d2ce82122a8bd5c62cf7 to your computer and use it in GitHub Desktop.
Save shoyer/6826d02949e4d2ce82122a8bd5c62cf7 to your computer and use it in GitHub Desktop.
Poisson pytree benchmark
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Poisson pytree benchmark",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/6826d02949e4d2ce82122a8bd5c62cf7/poisson-pytree-benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xBlNAa_KLVlO",
"colab_type": "text"
},
"source": [
"# Compare the performance of pytree vs raveled `cg` for solving the Poisson equation\n",
"\n",
"https://github.com/google/jax/issues/1531"
]
},
{
"cell_type": "code",
"metadata": {
"id": "egEvCSidFx-y",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import scipy.sparse.linalg\n",
"import jax\n",
"from functools import partial\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def axis_slice(ndim, index, axis):\n",
" \"\"\"Return an indexing tuple that applies `index` on `axis` only.\n",
"\n",
" Every other one of the `ndim` axes gets a full `slice(None)`.\n",
" \"\"\"\n",
" key = ndim * [slice(None)]\n",
" key[axis] = index\n",
" return tuple(key)\n",
"\n",
"def slice_along_axis(array, index, axis):\n",
" \"\"\"Index `array` with `index` along `axis`, leaving the other axes whole.\"\"\"\n",
" key = axis_slice(array.ndim, index, axis)\n",
" return array[key]\n",
"\n",
"def shift(array, offset, axis):\n",
" \"\"\"Shift `array` by `offset` positions along `axis`, zero-filling the gap.\n",
"\n",
" Along `axis`, the result is `out[i] = array[i + offset]` when `i + offset`\n",
" is a valid index and 0 otherwise.\n",
" \"\"\"\n",
" if offset >= 0:\n",
"  kept = slice(offset, None)\n",
" else:\n",
"  kept = slice(None, offset)\n",
" pad_width = [(0, 0) for _ in range(array.ndim)]\n",
" # Pad on the side the data moved away from, so the dropped cells become zeros.\n",
" pad_width[axis] = (max(-offset, 0), max(offset, 0))\n",
" return jnp.pad(slice_along_axis(array, kept, axis), pad_width,\n",
"                mode='constant', constant_values=0)\n",
"\n",
"def laplacian(array):\n",
" \"\"\"Apply the negated 5-point finite-difference Laplacian to a 2D grid.\n",
"\n",
" out[i, j] = 4 * u[i, j] - (u[i+1, j] + u[i-1, j] + u[i, j+1] + u[i, j-1]),\n",
" with values outside the grid treated as zero and no grid-spacing factor.\n",
" The sign is flipped relative to the usual Laplacian so the operator is\n",
" positive definite, which `cg` requires.\n",
" \"\"\"\n",
" # note: I believe this is faster than a convolution (at least on most platforms)\n",
" neighbor_sum = (shift(array, +1, axis=0) + shift(array, -1, axis=0)\n",
"                 + shift(array, +1, axis=1) + shift(array, -1, axis=1))\n",
" return 4 * array - neighbor_sum\n",
"\n",
"def laplacian_flat(array, shape=None):\n",
" \"\"\"Apply `laplacian` to a raveled 2D grid and return the result raveled.\n",
"\n",
" Args:\n",
"  array: 1D array holding a flattened 2D grid.\n",
"  shape: optional (rows, cols) of the grid. If omitted, the grid is assumed\n",
"   to be square and its side length is inferred from `array.shape[0]`.\n",
"\n",
" Returns:\n",
"  1D array with the same length as `array`.\n",
"\n",
" Raises:\n",
"  ValueError: if `shape` is omitted and the length is not a perfect square.\n",
" \"\"\"\n",
" if shape is None:\n",
"  # round() guards against float error in the square root; int() alone would\n",
"  # truncate a result like 511.9999... down to 511.\n",
"  size = int(round(array.shape[0] ** 0.5))\n",
"  if size * size != array.shape[0]:\n",
"   raise ValueError(f'cannot infer a square grid from length {array.shape[0]}')\n",
"  shape = (size, size)\n",
" convolved = laplacian(array.reshape(shape))\n",
" return convolved.reshape(-1)\n",
"\n",
"def make_source(shape):\n",
" \"\"\"Build the right-hand-side grid for the Poisson benchmark.\n",
"\n",
" The interior is zero; only the boundary rows and columns are set:\n",
"  - first and last rows: linear ramp from 0 to 1 along axis 1,\n",
"  - first column: 4 * y * (1 - y) with y running from 0 to 1 along axis 0,\n",
"  - last column: 1 - 4 * y * (1 - y).\n",
"\n",
" Args:\n",
"  shape: 2-tuple (rows, cols); non-square grids are supported.\n",
"\n",
" Returns:\n",
"  float64 numpy array of the given shape.\n",
" \"\"\"\n",
" assert len(shape) == 2\n",
" # Axis 0 indexes rows and axis 1 indexes columns: a row holds shape[1]\n",
" # values and a column holds shape[0], so size each profile accordingly\n",
" # (otherwise non-square shapes fail to broadcast).\n",
" x = np.linspace(0, 1, num=shape[1])\n",
" y = np.linspace(0, 1, num=shape[0])\n",
" source = np.zeros(shape)\n",
" source[0, :] = x\n",
" source[-1, :] = x\n",
" source[:, 0] = 4 * y * (1 - y)\n",
" source[:, -1] = 1 - 4 * y * (1 - y)\n",
" return source\n",
"\n",
"# The functions we'll be benchmarking\n",
"def jax_poisson_cg_solve(b, x0):\n",
" \"\"\"Solve `laplacian(x) = b` with JAX's `cg`, keeping `b` and `x0` as 2D grids.\n",
"\n",
" Both tolerances are 0, so the fixed `MAX_ITER` (a notebook global), rather\n",
" than convergence, sets how much work is done. Returns the solution with the\n",
" same shape as `b`.\n",
" \"\"\"\n",
" return jax.scipy.sparse.linalg.cg(\n",
"     laplacian, b, x0, tol=0, atol=0, maxiter=MAX_ITER)[0]\n",
"\n",
"def jax_poisson_cg_solve_flat(b, x0):\n",
" \"\"\"Same solve as `jax_poisson_cg_solve`, but on raveled (1D) `b` and `x0`.\n",
"\n",
" Uses `laplacian_flat` as the operator, zero tolerances and `MAX_ITER`\n",
" iterations; returns a 1D solution.\n",
" \"\"\"\n",
" b_flat, x0_flat = b.ravel(), x0.ravel()\n",
" return jax.scipy.sparse.linalg.cg(\n",
"     laplacian_flat, b_flat, x0_flat, tol=0, atol=0, maxiter=MAX_ITER)[0]\n",
"\n",
"# simulation parameters\n",
"MAX_ITER = 500  # maxiter passed to every CG solve below\n",
"shape = (512, 512)  # (rows, cols) of the grid"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3oz_t6TGLWYY",
"colab_type": "text"
},
"source": [
"# Benchmark SciPy"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MOj6Cx9QGF7n",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "f83cec55-8993-40d4-987b-c134102bc160"
},
"source": [
"@jax.jit\n",
"def matvec(x):\n",
" \"\"\"Apply `laplacian` to a raveled grid of shape `shape`; returns it raveled.\"\"\"\n",
" grid = jnp.reshape(x, shape)\n",
" return laplacian(grid).ravel()\n",
"\n",
"source = make_source(shape)\n",
"b = -source.ravel()\n",
"x0 = np.zeros(shape).ravel()\n",
"# SciPy runs the CG loop on the host, calling the jitted `matvec` each step.\n",
"A = scipy.sparse.linalg.LinearOperator(\n",
" (int(np.prod(shape)),) * 2, matvec, dtype=np.float32)\n",
"# Warm-up solve: the first call also compiles `matvec`, so %timeit excludes it.\n",
"solution, info = scipy.sparse.linalg.cg(A, b, x0=x0, tol=0, atol=0, maxiter=MAX_ITER)\n",
"%timeit scipy.sparse.linalg.cg(A, b, x0=x0, tol=0, atol=0, maxiter=MAX_ITER)\n",
"# Norm of the residual A @ solution - b.\n",
"print(f'Error: {np.linalg.norm(matvec(solution) - b)}')"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 2.05 s per loop\n",
"Error: 0.037856701761484146\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TilfAxiKLX92",
"colab_type": "text"
},
"source": [
"# Benchmark solving with 2D arrays\n",
"\n",
"On CPU, the JAX solver is about 2.5x faster than SciPy.\n",
"\n",
"On GPU, it is roughly another 30x faster than on CPU."
]
},
{
"cell_type": "code",
"metadata": {
"id": "QpObg6D3JYD1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
},
"outputId": "bff04ec9-9f2c-4067-97d7-f370154c261d"
},
"source": [
"# Build the right-hand side once; it does not depend on the backend.\n",
"source = make_source(shape)\n",
"b = jnp.asarray(-source).block_until_ready()\n",
"x0 = jnp.zeros_like(source).block_until_ready()\n",
"for backend in ['cpu', 'gpu']:\n",
" try:\n",
"  jax.devices(backend)\n",
" except RuntimeError:\n",
"  # No such backend on this runtime (e.g. no GPU attached): skip it.\n",
"  print(f\"{backend.upper()} test: skipped (backend not available)\")\n",
"  continue\n",
" cg_solve = jax.jit(jax_poisson_cg_solve, backend=backend)\n",
" # The first call compiles the solver, so it is excluded from the timing.\n",
" solution = cg_solve(b, x0).block_until_ready()\n",
" print(f\"{backend.upper()} test:\")\n",
" %timeit cg_solve(b, x0).block_until_ready()\n",
" print(f'Error: {np.linalg.norm(laplacian(solution) + source)}')"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"CPU test:\n",
"1 loop, best of 3: 813 ms per loop\n",
"Error: 0.03532074764370918\n",
"GPU test:\n",
"10 loops, best of 3: 28.2 ms per loop\n",
"Error: 0.03786032646894455\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "duh2QcQcLcBu",
"colab_type": "text"
},
"source": [
"# Benchmark flattened arrays\n",
"\n",
"On CPU, the flattened version is about 57% slower than the 2D version.\n",
"\n",
"On GPU, it is about 90% slower."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Gu5QQwcvMgjZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
},
"outputId": "68fc1b2d-2ca0-49ea-d914-ccd42aedbf04"
},
"source": [
"# Build the right-hand side once; it does not depend on the backend.\n",
"source = make_source(shape)\n",
"b = jnp.asarray(-source).block_until_ready()\n",
"x0 = jnp.zeros_like(source).block_until_ready()\n",
"for backend in ['cpu', 'gpu']:\n",
" try:\n",
"  jax.devices(backend)\n",
" except RuntimeError:\n",
"  # No such backend on this runtime (e.g. no GPU attached): skip it.\n",
"  print(f\"{backend.upper()} test: skipped (backend not available)\")\n",
"  continue\n",
" cg_solve = jax.jit(jax_poisson_cg_solve_flat, backend=backend)\n",
" # The first call compiles the solver, so it is excluded from the timing.\n",
" solution = cg_solve(b, x0).block_until_ready()\n",
" print(f\"{backend.upper()} test:\")\n",
" %timeit cg_solve(b, x0).block_until_ready()\n",
" print(f'Error: {np.linalg.norm(laplacian_flat(solution) + source.ravel())}')"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"CPU test:\n",
"1 loop, best of 3: 1.28 s per loop\n",
"Error: 0.03532460331916809\n",
"GPU test:\n",
"10 loops, best of 3: 53.6 ms per loop\n",
"Error: 0.03785990551114082\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iR0zyObSP_Aj",
"colab_type": "text"
},
"source": [
"## Make a pretty picture\n",
"\n",
"Verify that we did, indeed, approximately solve the Poisson equation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kEHiGHlDQg1-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 286
},
"outputId": "c477f1d3-59e0-4cc8-8759-ffe3ae28ed1f"
},
"source": [
"import matplotlib.pyplot as plt\n",
"# `solution` is the flat result of the last solver run in the previous cell.\n",
"fig, ax = plt.subplots()\n",
"image = ax.imshow(solution.reshape(shape))\n",
"ax.set_title('Approximate solution of the Poisson equation')\n",
"fig.colorbar(image, ax=ax);"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f1e70ef0eb8>"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4-l8j-44QhPj",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment