Skip to content

Instantly share code, notes, and snippets.

@RutgerK
Created March 14, 2019 16:27
Show Gist options
  • Save RutgerK/352839af57fd986a88566803669cc8cc to your computer and use it in GitHub Desktop.
Save RutgerK/352839af57fd986a88566803669cc8cc to your computer and use it in GitHub Desktop.
Numba heapq test
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numba\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from heapq import heappush, heappop\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@numba.njit\n",
"def aldous_broder(ys=20, xs=20):\n",
"    \"\"\"Generate a random maze with an Aldous-Broder random walk\n",
"\n",
"    ys: number of maze cells along y\n",
"    xs: number of maze cells along x\n",
"\n",
"    Returns a float array of shape (2*ys + 1, 2*xs + 1) in which\n",
"    carved cells and passages are 1 and walls are 0.\n",
"    \"\"\"\n",
"\n",
"    # random start cell (x is drawn before y)\n",
"    x_start = np.random.randint(xs)\n",
"    y_start = np.random.randint(ys)\n",
"\n",
"    # cells sit on the odd indices, the even indices in between are walls\n",
"    maze = np.zeros((2 * ys + 1, 2 * xs + 1))\n",
"\n",
"    # (y, x) offsets to the four neighboring cells\n",
"    neighbors = [[-2, 0],   # up\n",
"                 [ 0, 2],   # right\n",
"                 [ 2, 0],   # down\n",
"                 [ 0,-2]]   # left\n",
"\n",
"    # current position of the walk\n",
"    y = 2 * y_start + 1\n",
"    x = 2 * x_start + 1\n",
"\n",
"    # previous position of the walk\n",
"    yp = y\n",
"    xp = x\n",
"\n",
"    n_cells = ys * xs\n",
"    n_filled = 0\n",
"    while n_filled < n_cells:\n",
"\n",
"        if maze[y, x] == 0:\n",
"            # first visit of this cell: carve it\n",
"            maze[y, x] = 1\n",
"\n",
"            if n_filled > 0:\n",
"                # since we are taking steps of 2 (to get \"walls\"),\n",
"                # \"break\" the wall towards the previous position\n",
"                if yp == y:\n",
"                    maze[yp, (x + xp) // 2] = 1\n",
"                elif xp == x:\n",
"                    maze[(y + yp) // 2, xp] = 1\n",
"\n",
"            n_filled += 1\n",
"\n",
"        # the current position becomes the previous one\n",
"        yp = y\n",
"        xp = x\n",
"\n",
"        # the (potential) neighbors of the current position\n",
"        candidates = [(y + y_off, x + x_off) for y_off, x_off in neighbors]\n",
"\n",
"        while len(candidates) > 0:\n",
"            # pick a random remaining candidate\n",
"            yn, xn = candidates.pop(np.random.randint(0, len(candidates)))\n",
"\n",
"            # step to the first candidate that lies inside the maze\n",
"            inside_y = yn >= 0 and yn // 2 <= ys - 1\n",
"            inside_x = xn >= 0 and xn // 2 <= xs - 1\n",
"            if inside_y and inside_x:\n",
"                y = yn\n",
"                x = xn\n",
"                break\n",
"\n",
"    return maze"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def dijkstra_distance_python(grid, x, y):\n",
"    \"\"\"Simple dijkstra distance, calculates the distance from a\n",
"    start cell for an entire grid (not just one shortest path)\n",
"\n",
"    grid: 2D array defining the grid,\n",
"          non-zero cells are passable\n",
"          zeros are impassable (eg walls)\n",
"\n",
"    x: starting x index (second axis of grid)\n",
"    y: starting y index (first axis of grid)\n",
"\n",
"    Returns a float32 array with the shape of grid that holds the\n",
"    number of steps from the start cell. Cells that are never\n",
"    reached (walls included) keep the value inf. The start cell\n",
"    itself is not checked for being passable and always gets 0.\n",
"    \"\"\"\n",
"\n",
"    ys, xs = grid.shape\n",
"    start_dist = 0.0\n",
"\n",
"    # (y, x) offsets to the four direct neighbors\n",
"    neighbors = [[-1, 0],   # up\n",
"                 [ 0, 1],   # right\n",
"                 [ 1 ,0],   # down\n",
"                 [ 0,-1]]   # left\n",
"\n",
"    distance = np.full_like(grid, np.inf, dtype=np.float32)\n",
"\n",
"    # min-heap of (distance, y, x) entries, seeded with the start cell\n",
"    frontier = [(start_dist, y, x)]\n",
"\n",
"    while len(frontier) > 0:\n",
"        # closest cell that still has to be expanded\n",
"        dist, y, x = heappop(frontier)\n",
"        distance[y,x] = dist\n",
"\n",
"        for y_off, x_off in neighbors:\n",
"\n",
"            # neighbor location\n",
"            yn = y + y_off\n",
"            xn = x + x_off\n",
"\n",
"            # out of bounds\n",
"            if yn < 0 or yn > ys-1:\n",
"                continue\n",
"            if xn < 0 or xn > xs-1:\n",
"                continue\n",
"\n",
"            # skip impassable cells (zeros in the grid)\n",
"            if grid[yn,xn] == 0:\n",
"                continue\n",
"\n",
"            # neighbor distance, every step costs 1\n",
"            ndist = distance[y,x] + 1\n",
"\n",
"            # only (re)visit a neighbor that is new or gets closer\n",
"            if np.isinf(distance[yn, xn]) or ndist < distance[yn,xn]:\n",
"                distance[yn,xn] = ndist\n",
"                heappush(frontier, (ndist, yn, xn))\n",
"\n",
"    return distance"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@numba.njit\n",
"def get_path_from(distance, x_end=-2, y_end=-2):\n",
"    \"\"\"Get the shortest path by walking in reverse from an end point\n",
"\n",
"    distance: 2D array of distances (eg the result of the\n",
"              dijkstra distance calculation)\n",
"    x_end: x index of the end point, negative values count from the end\n",
"    y_end: y index of the end point, negative values count from the end\n",
"\n",
"    Returns a list of (y, x) tuples. It starts at the end point and\n",
"    keeps stepping to the neighbor with the smallest distance, until\n",
"    no neighbor is closer than the current cell. Every cell occurs\n",
"    once; if the end point has no closer neighbor at all the list\n",
"    only holds the end point.\n",
"    \"\"\"\n",
"\n",
"    ys, xs = distance.shape\n",
"\n",
"    # start of search, the modulo maps negative indices to valid ones\n",
"    y = y_end % ys\n",
"    x = x_end % xs\n",
"\n",
"    dist = distance[y, x]\n",
"    path = [(y, x)]\n",
"\n",
"    # (y, x) offsets to the four direct neighbors\n",
"    neighbors = [[-1, 0],   # up\n",
"                 [ 0, 1],   # right\n",
"                 [ 1, 0],   # down\n",
"                 [ 0,-1]]   # left\n",
"\n",
"    while True:\n",
"\n",
"        # best candidate so far; it stays at the current cell when\n",
"        # no neighbor has a smaller distance\n",
"        yn_best = y\n",
"        xn_best = x\n",
"\n",
"        for y_off, x_off in neighbors:\n",
"\n",
"            yn = y + y_off\n",
"            xn = x + x_off\n",
"\n",
"            # out of bounds\n",
"            if xn < 0 or xn > xs-1:\n",
"                continue\n",
"            if yn < 0 or yn > ys-1:\n",
"                continue\n",
"\n",
"            ndist = distance[yn, xn]\n",
"\n",
"            if ndist < dist:\n",
"                dist = ndist\n",
"                yn_best = yn\n",
"                xn_best = xn\n",
"\n",
"        # no closer neighbor: the walk has arrived. Testing this directly\n",
"        # (instead of counting rejected neighbors) also terminates when\n",
"        # the final cell lies on the border of the array, and avoids\n",
"        # appending the final cell a second time.\n",
"        if yn_best == y and xn_best == x:\n",
"            break\n",
"\n",
"        y = yn_best\n",
"        x = xn_best\n",
"        path.append((y, x))\n",
"\n",
"    return path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compile to Numba"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# wrap the pure-Python function with the Numba JIT; the original\n",
"# function stays available for the comparison below\n",
"dijkstra_distance_numba = numba.njit(dijkstra_distance_python)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate Maze (Aldous-Broder algorithm)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# generate maze\n",
"# maze_dims is the number of cells as (ys, xs); the array returned by\n",
"# aldous_broder has shape (2*ys + 1, 2*xs + 1) because of the walls\n",
"maze_dims = (50, 50)\n",
"maze = aldous_broder(*maze_dims)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dijkstra distance (for entire grid)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# start in the upper-left corner (x=1, y=1 is the first cell\n",
"# inside the outer wall of the maze)\n",
"dist_py = dijkstra_distance_python(maze, 1, 1)\n",
"dist_num = dijkstra_distance_numba(maze, 1, 1)\n",
"\n",
"# the compiled version has to give exactly the same result\n",
"np.testing.assert_array_equal(dist_num, dist_py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test performance"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63.5 ms ± 2.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"1.25 ms ± 98.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
},
{
"data": {
"text/plain": [
"'Numba speedup: 50.79x'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# time both versions on the same maze; -o makes %timeit return its\n",
"# result object, so the timings can be compared below\n",
"t_py = %timeit -o dist_py = dijkstra_distance_python(maze, 1, 1)\n",
"t_num = %timeit -o dist_num = dijkstra_distance_numba(maze, 1, 1)\n",
"\n",
"# ratio of the best run of each version\n",
"'Numba speedup: {:1.2f}x'.format(t_py.best / t_num.best)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Extract shortest path"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# walk back from the bottom-right corner towards the start\n",
"shortest_path = get_path_from(dist_num, x_end=-2, y_end=-2)\n",
"\n",
"# split the (y, x) tuples into the y and x vertices of the line to plot\n",
"yverts, xverts = zip(*shortest_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot result"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2b3090fe748>]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x864 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1,1, figsize=(12, 12), constrained_layout=True, subplot_kw=dict(xticks=[], yticks=[]))\n",
"\n",
"# walls have an infinite distance; imshow masks those values and, with\n",
"# the default (transparent) \"bad\" color of the colormap, leaves them\n",
"# unpainted. A black axes background therefore draws the walls in black\n",
"# without modifying the globally registered viridis colormap, which\n",
"# would silently change every later plot in this kernel.\n",
"ax.set_facecolor('k')\n",
"\n",
"ax.imshow(dist_num, origin='upper', cmap=plt.cm.viridis)\n",
"\n",
"# shortest path on top; the trailing semicolon hides the Line2D repr\n",
"ax.plot(xverts, yverts, 'w-', lw=4);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment