Created
March 23, 2020 14:19
-
-
Save dionhaefner/a97ef80b77e02b36e4b248bb97541161 to your computer and use it in GitHub Desktop.
Benchmarks of the tridiagonal matrix algorithm in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Tridiagonal matrix solver benchmarks" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"env: XLA_PYTHON_CLIENT_PREALLOCATE=false\n", | |
"env: OMP_NUM_THREADS=1\n", | |
"env: CUDA_VISIBLE_DEVICES=0\n" | |
] | |
} | |
], | |
"source": [ | |
"%env XLA_PYTHON_CLIENT_PREALLOCATE=false\n", | |
"%env OMP_NUM_THREADS=1\n", | |
"%env CUDA_VISIBLE_DEVICES=0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Mon Mar 23 15:02:13 2020 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla P100-PCIE... Off | 00000000:03:00.0 Off | 0 |\n", | |
"| N/A 24C P0 29W / 250W | 0MiB / 16280MiB | 0% Default |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
"| 1 Tesla P100-PCIE... Off | 00000000:04:00.0 Off | 0 |\n", | |
"| N/A 25C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: GPU Memory |\n", | |
"| GPU PID Type Process name Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
] | |
} | |
], | |
"source": [ | |
"!nvidia-smi" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"import numpy as np\n", | |
"import numba as nb\n", | |
"import numba.cuda\n", | |
"from scipy.linalg import lapack" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"jax.config.update('jax_enable_x64', True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# 360 x 160 independent tridiagonal systems with 115 unknowns each.
# Every solver below treats the last axis as the system dimension.
shape = 360, 160, 115
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Implement TDMA\n", | |
"\n", | |
"#### NumPy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def tdma_naive(a, b, c, d):
    """
    Solve many tridiagonal systems at once with the Thomas algorithm (NumPy reference).

    Each system lives along the last axis: ``a`` is the sub-diagonal, ``b`` the main
    diagonal, ``c`` the super-diagonal and ``d`` the right-hand side. ``a[..., 0]``
    and ``c[..., -1]`` are never read. All leading axes are independent systems.

    The inputs are left untouched: the forward sweep works on private copies of
    ``b`` and ``d`` (at the cost of two array copies), so the same arrays can be
    passed repeatedly, e.g. inside ``%timeit``.

    Returns a new array of the same shape as ``a`` holding the solutions.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    # The forward elimination overwrites the diagonal and the RHS; do it on copies
    # so the caller's arrays are not corrupted between calls.
    b = b.copy()
    d = d.copy()

    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] -= w * c[..., i - 1]
        d[..., i] -= w * d[..., i - 1]

    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### LAPACK" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def tdma_lapack(a, b, c, d):
    """
    Solve many tridiagonal systems with a single LAPACK ``dgtsv`` call.

    Systems run along the last axis: ``a`` is the sub-diagonal (``a[..., 0]`` unused),
    ``b`` the main diagonal, ``c`` the super-diagonal (``c[..., -1]`` unused) and ``d``
    the right-hand side. All systems are stacked into one large tridiagonal matrix.
    The entries that would couple neighbouring systems are zeroed in flattened
    copies, so the caller's arrays are never modified.

    Returns the solutions reshaped to ``a.shape``.
    Raises ``np.linalg.LinAlgError`` when LAPACK reports ``info != 0``
    (e.g. an exactly singular system).
    """
    n = a.shape[-1]

    # flatten() returns copies, so the zeroing below does not touch the inputs.
    sub = a.flatten()
    sup = c.flatten()
    sub[::n] = 0.0       # a[..., 0]: no coupling to the previous system
    sup[n - 1::n] = 0.0  # c[..., -1]: no coupling to the next system

    _, _, _, x, info = lapack.dgtsv(sub[1:], b.flatten(), sup[:-1], d.flatten())
    if info != 0:
        raise np.linalg.LinAlgError(f'LAPACK dgtsv failed with info={info}')
    return x.reshape(a.shape)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Numba CPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
@nb.guvectorize([(nb.float64[:],) * 5], '(n), (n), (n), (n) -> (n)', nopython=True)
def tdma_numba(a, b, c, d, out):
    """
    Thomas algorithm for one tridiagonal system, broadcast over leading axes by Numba.

    ``a``/``b``/``c`` are the sub-/main/super-diagonals (``a[0]`` and ``c[-1]`` are
    never read) and ``d`` is the right-hand side; the solution is written to ``out``.

    Inputs are treated as read-only. Numba documents that writes to gufunc inputs
    may or may not reach the caller, and when they do they clobber the caller's data.
    The eliminated diagonal therefore lives in a scratch array (one small allocation
    per system), and the eliminated RHS is accumulated directly in ``out``.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[0]

    bp = np.empty_like(b)  # eliminated main diagonal
    bp[0] = b[0]
    out[0] = d[0]          # out holds the eliminated RHS during the forward sweep

    for i in range(1, n):
        w = a[i] / bp[i - 1]
        bp[i] = b[i] - w * c[i - 1]
        out[i] = d[i] - w * out[i - 1]

    out[-1] = out[-1] / bp[-1]

    # Back substitution in place: out[i] still holds the eliminated RHS when read.
    for i in range(n - 2, -1, -1):
        out[i] = (out[i] - c[i] * out[i + 1]) / bp[i]
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Numba CUDA" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Numba CUDA local arrays need a compile-time size, so the kernel's per-thread
# scratch space is fixed to the system length of the global `shape`.
nconst = shape[-1]


@nb.cuda.jit()
def tdma_numba_cuda_kernel(a, b, c, d, out):
    # One thread per (i, j) solves the system along the last axis with the
    # Thomas algorithm in its c'/d' ("prime") form.
    i, j = nb.cuda.grid(2)

    # The grid is rounded up to whole blocks; threads past the array edge exit.
    if not (i < a.shape[0] and j < a.shape[1]):
        return

    n = a.shape[2]

    cp = nb.cuda.local.array((nconst,), dtype=nb.float64)
    dp = nb.cuda.local.array((nconst,), dtype=nb.float64)

    cp[0] = c[i, j, 0] / b[i, j, 0]
    dp[0] = d[i, j, 0] / b[i, j, 0]

    for k in range(1, n):
        norm_factor = b[i, j, k] - a[i, j, k] * cp[k - 1]
        cp[k] = c[i, j, k] / norm_factor
        dp[k] = (d[i, j, k] - a[i, j, k] * dp[k - 1]) / norm_factor

    out[i, j, n - 1] = dp[n - 1]

    for k in range(n - 2, -1, -1):
        out[i, j, k] = dp[k] - cp[k] * out[i, j, k + 1]


def tdma_numba_cuda(a, b, c, d, threadsperblock=(16, 16)):
    """
    Solve a batch of tridiagonal systems on the GPU, one CUDA thread per system.

    ``a``, ``b``, ``c``, ``d`` (sub-, main, super-diagonal, RHS) must share a 3-D
    shape ``(nx, ny, n)`` with ``1 <= n <= nconst``. The kernel's scratch arrays hold
    ``nconst`` entries and would be overrun by longer systems. Inputs may be host
    arrays or Numba device arrays. ``threadsperblock`` is the 2-D block shape.

    Returns a Numba device array. The kernel launch is not waited for here.
    Raises ``ValueError`` for inputs the compiled kernel cannot handle.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    if len(a.shape) != 3:
        raise ValueError(f'expected 3-D arrays, got shape {a.shape}')
    if not 1 <= a.shape[-1] <= nconst:
        raise ValueError(
            f'system size {a.shape[-1]} is outside the compiled range 1..{nconst}'
        )

    blockspergrid = tuple(
        math.ceil(size / threads) for size, threads in zip(a.shape[:2], threadsperblock)
    )

    out = nb.cuda.device_array(a.shape, dtype=a.dtype)
    tdma_numba_cuda_kernel[blockspergrid, threadsperblock](a, b, c, d, out)
    return out
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### JAX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax.numpy as jnp\n", | |
"import jax.lax\n", | |
"\n", | |
"\n", | |
def tdma_jax_kernel(a, b, c, d):
    """
    Thomas algorithm in JAX: two ``lax.scan`` passes over the system axis.

    Systems run along the last axis of ``a``/``b``/``c``/``d`` (sub-, main-,
    super-diagonal and right-hand side). ``a[..., 0]`` and ``c[..., -1]`` are only
    ever multiplied by the zero initial carries. Any number of leading batch axes is
    supported. The arrays are transposed so that ``scan`` iterates over the system
    axis, and the solution is transposed back to the input layout.
    """
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a.T, b.T, c.T, d.T)
    # The carry has the shape of one slice of the transposed arrays (the reversed
    # batch shape) and the promoted input dtype. scan requires the carry type to
    # stay fixed, so a hard-coded shape/dtype would break for other inputs.
    init = jnp.zeros(a.shape[-2::-1], dtype=jnp.result_type(a, b, c, d))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T
"\n", | |
"\n", | |
# One jitted copy of the same kernel per backend: (CPU, GPU).
tdma_jax, tdma_jax_cuda = (
    jax.jit(tdma_jax_kernel, backend=name) for name in ('cpu', 'gpu')
)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### CuPy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import cupy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
from string import Template

# The kernel is specialised for the benchmark `shape` at compile time:
# SYS_DEPTH is the system length (and sizes the per-thread scratch arrays),
# SIZE is the total element count that bounds each thread's offset.
_CUPY_SYS_DEPTH = int(shape[-1])
_CUPY_TOTAL_SIZE = int(np.prod(shape))  # np.product was removed in NumPy 2.0

kernel = Template('''
extern "C" __global__
void execute(
    const ${DTYPE} *a,
    const ${DTYPE} *b,
    const ${DTYPE} *c,
    const ${DTYPE} *d,
    ${DTYPE} *solution
){
    const size_t m = ${SYS_DEPTH};
    const size_t total_size = ${SIZE};
    const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * m;

    if (idx >= total_size) {
        return;
    }

    ${DTYPE} cp[${SYS_DEPTH}];
    ${DTYPE} dp[${SYS_DEPTH}];

    cp[0] = c[idx] / b[idx];
    dp[0] = d[idx] / b[idx];

    for (ptrdiff_t j = 1; j < m; ++j) {
        const ${DTYPE} norm_factor = b[idx+j] - a[idx+j] * cp[j-1];
        cp[j] = c[idx+j] / norm_factor;
        dp[j] = (d[idx+j] - a[idx+j] * dp[j-1]) / norm_factor;
    }

    solution[idx + m-1] = dp[m-1];
    for (ptrdiff_t j=m-2; j >= 0; --j) {
        solution[idx + j] = dp[j] - cp[j] * solution[idx + j+1];
    }
}
''').substitute(
    DTYPE='double',
    SYS_DEPTH=_CUPY_SYS_DEPTH,
    SIZE=_CUPY_TOTAL_SIZE,
)

tdma_cupy_kernel = cupy.RawKernel(kernel, 'execute')


def tdma_cupy(a, b, c, d, blocksize=256):
    """
    Solve a batch of tridiagonal systems with a raw CUDA kernel via CuPy.

    One GPU thread solves one system along the last axis (sub-diagonal ``a``, main
    diagonal ``b``, super-diagonal ``c``, right-hand side ``d``). The kernel was
    compiled for the module-level ``shape``, so the inputs must have exactly that
    system length and total size.

    Inputs are converted to C-contiguous float64 CuPy arrays, because the kernel
    indexes raw ``double`` pointers in C order. Arrays that already qualify should
    not be copied. ``blocksize`` is the number of threads per block.

    Returns a CuPy array; the launch is not synchronized here.
    Raises ``ValueError`` when the input shape does not match the compiled kernel.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    if a.shape[-1] != _CUPY_SYS_DEPTH or a.size != _CUPY_TOTAL_SIZE:
        raise ValueError(
            f'tdma_cupy was compiled for systems of length {_CUPY_SYS_DEPTH} and '
            f'{_CUPY_TOTAL_SIZE} elements in total; got shape {a.shape}'
        )

    a, b, c, d = (cupy.asarray(k, dtype=cupy.float64, order='C') for k in (a, b, c, d))
    out = cupy.empty(a.shape, dtype=a.dtype)

    tdma_cupy_kernel(
        (math.ceil(a.size / a.shape[-1] / blocksize),),
        (blocksize,),
        (a, b, c, d, out)
    )

    return out
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Check results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"✔️\n", | |
"✔️\n", | |
"✔️\n", | |
"✔️\n", | |
"✔️\n", | |
"✔️\n" | |
] | |
} | |
], | |
"source": [ | |
def _fresh_inputs():
    # Re-seed every time so each solver sees exactly the same random inputs,
    # even if a previous solver overwrote its arguments.
    np.random.seed(17)
    return np.random.randn(4, *shape)


a, b, c, d = _fresh_inputs()
res_naive = tdma_naive(a, b, c, d)

solvers = (tdma_cupy, tdma_lapack, tdma_numba, tdma_numba_cuda, tdma_jax, tdma_jax_cuda)
for solver in solvers:
    a, b, c, d = _fresh_inputs()
    out = solver(a, b, c, d)

    # Results exposing .get() (CuPy arrays) are copied to the host; others pass through.
    try:
        out = out.get()
    except AttributeError:
        pass

    np.testing.assert_allclose(out, res_naive)
    print('✔️')
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Benchmark" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Identically seeded inputs shared by all benchmarks below.
np.random.seed(17)
a, b, c, d = np.random.randn(*((4,) + shape))
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### CPU" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### NumPy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"403 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_naive(a, b, c, d)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### LAPACK" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"317 ms ± 7.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_lapack(a, b, c, d)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Numba" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"130 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_numba(a, b, c, d)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### JAX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"325 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_jax(a, b, c, d).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### GPU" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Numba" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Copy the inputs to the GPU once, outside the timed region.
ac, bc, cc, dc = map(nb.cuda.to_device, (a, b, c, d))
tdma_numba_cuda(ac, bc, cc, dc);  # trigger compilation
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"13.2 ms ± 67.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_numba_cuda(ac, bc, cc, dc)\n", | |
"numba.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### JAX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Put the inputs on JAX's default device once, outside the timed region.
aj, bj, cj, dj = (jnp.array(k).block_until_ready() for k in (a, b, c, d))
# Compile (or hit the compilation cache) before timing, as the Numba and CuPy GPU
# benchmarks do, so %timeit never calibrates on a compiling call.
tdma_jax_cuda(aj, bj, cj, dj).block_until_ready();
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"14 ms ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_jax_cuda(aj, bj, cj, dj).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### CuPy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
stream = cupy.cuda.stream.Stream()

# Upload the inputs and launch once on the benchmark stream, then wait for it.
with stream:
    ac, bc, cc, dc = [cupy.asarray(arr) for arr in (a, b, c, d)]
    tdma_cupy(ac, bc, cc, dc)  # trigger compilation

stream.synchronize()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5.06 ms ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"with stream:\n", | |
" tdma_cupy(ac, bc, cc, dc)\n", | |
"\n", | |
"stream.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Try JAX without transposes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def tdma_jax_kernel_notrans(a, b, c, d):
    """
    Variant of the JAX Thomas solver that expects the system axis FIRST.

    ``a``/``b``/``c``/``d`` (sub-, main-, super-diagonal and right-hand side) have
    shape ``(n, *batch)`` with any number of batch axes. ``lax.scan`` iterates
    directly over axis 0, so no transposes are needed inside the kernel.
    ``a[0]`` and ``c[-1]`` are only ever multiplied by the zero initial carries.
    Returns the solution in the same ``(n, *batch)`` layout.
    """
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a, b, c, d)
    # The carry is one slice along axis 0 with the promoted input dtype. scan
    # requires the carry type to stay fixed, so it is derived from the inputs.
    init = jnp.zeros(a.shape[1:], dtype=jnp.result_type(a, b, c, d))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1]
"\n", | |
"\n", | |
# One jitted copy of the system-axis-first kernel per backend: (CPU, GPU).
tdma_jax_notrans, tdma_jax_cuda_notrans = (
    jax.jit(tdma_jax_kernel_notrans, backend=name) for name in ('cpu', 'gpu')
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Reverse all axes so the system dimension comes first (NumPy views, no copies).
at, bt, ct, dt = a.T, b.T, c.T, d.T
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"537 ms ± 3.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_jax_notrans(at, bt, ct, dt).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# System-axis-first copies of the inputs on JAX's default device.
aj, bj, cj, dj = (jnp.array(k.T).block_until_ready() for k in (a, b, c, d))
# tdma_jax_cuda_notrans is not called anywhere before its benchmark. Without this
# warm-up, JIT compilation would land in %timeit's loop-count calibration,
# leaving a single, noisy loop per run.
tdma_jax_cuda_notrans(aj, bj, cj, dj).block_until_ready();
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10.1 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"tdma_jax_cuda_notrans(aj, bj, cj, dj).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment