Skip to content

Instantly share code, notes, and snippets.

@lgarrison
Created February 25, 2022 18:29
Show Gist options
  • Save lgarrison/3dcf371cb9edb891cef92fe3db5c7d33 to your computer and use it in GitHub Desktop.
Save lgarrison/3dcf371cb9edb891cef92fe3db5c7d33 to your computer and use it in GitHub Desktop.
optimize calc_fenv
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "c15d3118",
"metadata": {},
"source": [
"The original calc_fenv() gave wrong results with `parallel=True` because of a bad parfor fusion (a Numba bug). There is a minimal reproducer at the bottom of the notebook that demonstrates the issue. But it turns out that a small manipulation of the code avoids the bug entirely, allowing us to use `parallel=True` safely.\n",
"\n",
"If one is paranoid, one can also disable parfor fusion entirely with `njit(parallel=dict(fusion=False))`, at some performance penalty.\n",
"\n",
"The original Numba bug will be fixed by: https://github.com/numba/numba/pull/7582"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9abe8d6d-d0c2-4c1a-a911-831029816c34",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import njit\n",
"import numba"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f3ed5448",
"metadata": {},
"outputs": [],
"source": [
"@njit(parallel=True)\n",
"def calc_fenv_opt(Menv, mbins, halosM):\n",
" fenv_rank = np.zeros(len(Menv))\n",
" for ibin in numba.prange(len(mbins)-1):\n",
" mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])\n",
" Nmask = np.sum(mmask)\n",
" if Nmask > 1:\n",
" new_fenv_rank = Menv[mmask].argsort().argsort()\n",
" fenv_rank[mmask] = new_fenv_rank / (Nmask-1) - 0.5 # max rank is always Nmask - 1\n",
" return fenv_rank\n",
"\n",
"def calc_fenv_orig(Menv, mbins, halosM):\n",
" fenv_rank = np.zeros(len(Menv))\n",
" for ibin in numba.prange(len(mbins)-1):\n",
" mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])\n",
" if np.sum(mmask) > 1:\n",
" new_fenv_rank = Menv[mmask].argsort().argsort()\n",
" fenv_rank[mmask] = new_fenv_rank / np.max(new_fenv_rank) - 0.5\n",
" return fenv_rank\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac5706fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.78 s ± 32.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"False\n",
"376 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"True\n"
]
}
],
"source": [
"N = 10**7\n",
"mbins = np.linspace(0, 1, 10)\n",
"rng = np.random.default_rng()\n",
"halosM = rng.random(N)\n",
"Menv = rng.random(N)\n",
"numba.set_num_threads(12)\n",
"\n",
"parallel_calc_fenv_orig = njit(parallel=True)(calc_fenv_orig)\n",
"serial_calc_fenv_orig = njit(parallel=False)(calc_fenv_orig)\n",
"\n",
"parres_orig = parallel_calc_fenv_orig(Menv, mbins, halosM)\n",
"%timeit global serres_orig; serres_orig = serial_calc_fenv_orig(Menv, mbins, halosM) # 2.8 sec\n",
"\n",
"print((parres_orig == serres_orig).all()) # False\n",
"\n",
"%timeit global parres_opt; parres_opt = calc_fenv_opt(Menv, mbins, halosM) # 380 ms\n",
"\n",
"print((parres_opt == serres_orig).all()) # True"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e6d57adb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"================================================================================\n",
" Parallel Accelerator Optimizing: Function calc_fenv_orig, \n",
"/tmp/ipykernel_883107/1823462690.py (12) \n",
"================================================================================\n",
"\n",
"\n",
"Parallel loop listing for Function calc_fenv_orig, /tmp/ipykernel_883107/1823462690.py (12) \n",
"---------------------------------------------------------------------------------|loop #ID\n",
"def calc_fenv_orig(Menv, mbins, halosM): | \n",
" fenv_rank = np.zeros(len(Menv))----------------------------------------------| #0\n",
" for ibin in numba.prange(len(mbins)-1):--------------------------------------| #5\n",
" mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])--------------| #1\n",
" if np.sum(mmask) > 1:----------------------------------------------------| #3\n",
" new_fenv_rank = Menv[mmask].argsort().argsort() | \n",
" fenv_rank[mmask] = new_fenv_rank / np.max(new_fenv_rank) - 0.5-------| #2, 4\n",
" return fenv_rank | \n",
"------------------------------ After Optimisation ------------------------------\n",
"Parallel region 0:\n",
"+--5 (parallel)\n",
" +--1 (serial, fused with loop(s): 3)\n",
" +--4 (serial, fused with loop(s): 2)\n",
"\n",
"\n",
" \n",
"Parallel region 0 (loop #5) had 2 loop(s) fused and 2 loop(s) serialized as part\n",
" of the larger parallel loop (#5).\n",
"--------------------------------------------------------------------------------\n",
"--------------------------------------------------------------------------------\n",
" \n"
]
}
],
"source": [
"# for fun, we can clearly see the bad fusion of loops 2 & 4\n",
"parallel_calc_fenv_orig.parallel_diagnostics(level=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f731cc43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"# The original calc_fenv() had a bad parfor fusion due to a Numba bug.\n",
"# f() is a minimal reproducer that demonstrates the issue.\n",
"# This will be fixed by: https://github.com/numba/numba/pull/7582\n",
"\n",
"def f():\n",
" a = np.arange(2)\n",
" amx = a.max()\n",
" res = np.empty(len(a))\n",
" res[:] = amx\n",
" return res\n",
"\n",
"numba.set_num_threads(1)\n",
"f_parallel = njit(parallel=True)(f)\n",
"f_serial = njit(parallel=False)(f)\n",
"\n",
"print(np.all(f_parallel() == f_serial()))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "31d477e5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"================================================================================\n",
" Parallel Accelerator Optimizing: Function f, \n",
"/tmp/ipykernel_883107/3926855332.py (5) \n",
"================================================================================\n",
"\n",
"\n",
"Parallel loop listing for Function f, /tmp/ipykernel_883107/3926855332.py (5) \n",
"------------------------------|loop #ID\n",
"def f(): | \n",
" a = np.arange(2) | \n",
" amx = a.max()-------------| #13\n",
" res = np.empty(len(a)) | \n",
" res[:] = amx--------------| #11\n",
" return res | \n",
"------------------------------ After Optimisation ------------------------------\n",
"Parallel region 0:\n",
"+--12 (parallel, fused with loop(s): 11, 13)\n",
"\n",
"\n",
" \n",
"Parallel region 0 (loop #12) had 2 loop(s) fused.\n",
"--------------------------------------------------------------------------------\n",
"--------------------------------------------------------------------------------\n",
" \n"
]
}
],
"source": [
"# as before, loops 11 & 13 have a bad fusion\n",
"f_parallel.parallel_diagnostics(level=1)"
]
}
],
"metadata": {
"interpreter": {
"hash": "5cd5cbe25001faa61ab76a271aef4113321f63d42e31cbdebc9d4a65270c2765"
},
"kernelspec": {
"display_name": "MyEnv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment