@telegraphic
Created March 30, 2021 07:33
Speedup frbpoppy notebook
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "available-navigation",
"metadata": {},
"outputs": [],
"source": [
"from frbpoppy import precalc as pc\n",
"import healpy as hp\n",
"import numpy as np\n",
"from astropy.coordinates import Angle"
]
},
{
"cell_type": "markdown",
"id": "collaborative-pillow",
"metadata": {},
"source": [
"### Basic profiling of CosmicPopulation"
]
},
{
"cell_type": "code",
"execution_count": 149,
"id": "increased-grass",
"metadata": {},
"outputs": [],
"source": [
"from frbpoppy import CosmicPopulation"
]
},
{
"cell_type": "code",
"execution_count": 150,
"id": "precise-pontiac",
"metadata": {},
"outputs": [],
"source": [
"# Now generate FRB population\n",
"n_source = 1000\n",
"n_days = 1.0\n",
"cosmic_pop = CosmicPopulation.complex(n_source, n_days=n_days)"
]
},
{
"cell_type": "code",
"execution_count": 153,
"id": "realistic-andorra",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.8 µs ± 3.27 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"1.04 s ± 15.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"137 ns ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)\n",
"382 µs ± 3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"125 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"1.44 s ± 17.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"179 µs ± 86.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"75.8 µs ± 341 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"45.3 µs ± 280 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%timeit cosmic_pop.gen_index()\n",
"%timeit cosmic_pop.gen_dist()\n",
"%timeit cosmic_pop.gen_time()\n",
"%timeit cosmic_pop.gen_direction()\n",
"%timeit cosmic_pop.gen_gal_coords()\n",
"%timeit cosmic_pop.gen_dm()\n",
"%timeit cosmic_pop.gen_w()\n",
"%timeit cosmic_pop.gen_lum()\n",
"%timeit cosmic_pop.gen_si()"
]
},
{
"cell_type": "markdown",
"id": "overall-johnson",
"metadata": {},
"source": [
"Two slowest functions are `gen_dist()` and `gen_dm()`. Target these for optimization"
]
},
{
"cell_type": "markdown",
"id": "previous-constraint",
"metadata": {},
"source": [
"## Speedup `gen_dm()`\n",
"\n",
"### Setup NE2001 lookup table\n",
"\n",
"First let's setup the current lookup system, which uses an sqlite table"
]
},
{
"cell_type": "code",
"execution_count": 189,
"id": "applicable-mineral",
"metadata": {},
"outputs": [],
"source": [
"ne2001 = pc.NE2001Table()\n",
"\n",
"def dm_func_sql(gl, gb):\n",
" z = lambda: ne2001.lookup(gl, gb)\n",
" return z()"
]
},
{
"cell_type": "markdown",
"id": "running-piece",
"metadata": {},
"source": [
"### Setup healpix-based lookup\n",
"\n",
"In lieu of sqlite, let's try using healpix coversions, which can be done by healpy ufuncs.\n",
"\n",
"The `hp.ang2pix` will convert a sky coordinate into a healpix pixel ID, and `hp.pix2ang` does the opposite.\n",
"\n",
"Using a precomputed NE2001 map."
]
},
{
"cell_type": "code",
"execution_count": 146,
"id": "republican-legislature",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NSIDE = 128\n",
"ORDERING = RING in fits file\n",
"INDXSCHM = IMPLICIT\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 612x388.8 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"data = hp.read_map('../data/models/healpix/dm-ne2001-30kpc.fits')\n",
"hp.mollview(data)\n",
"\n",
"nside = hp.npix2nside(len(data)) # This sets map resolution (NSIDE)\n",
"npix = hp.nside2npix(nside) # Compute number of pixels in healpix map\n",
"pid = np.arange(npix) # Give each pixel an index\n",
"\n",
"def dm_func_healpix(gl, gb):\n",
" pixloc = hp.ang2pix(nside, gl, gb, lonlat=True)\n",
" return data[pixloc]"
]
},
{
"cell_type": "markdown",
"id": "beautiful-distributor",
"metadata": {},
"source": [
"## Compare timings\n",
"\n",
"Create a grid of galactic (lon,lat) points for testing"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "binary-duncan",
"metadata": {},
"outputs": [],
"source": [
"gl = np.linspace(-180, 180, 360)\n",
"gb = np.linspace(-90, 90, 180)\n",
"\n",
"gl_grid, gb_grid = np.meshgrid(gl, gb)\n",
"gl_grid = gl_grid.ravel()\n",
"gb_grid = gb_grid.ravel()"
]
},
{
"cell_type": "markdown",
"id": "radical-building",
"metadata": {},
"source": [
"Now run timing test:"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "significant-translator",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.14 ms ± 16.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit dm_func_healpix(gl_grid, gb_grid)"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "sunset-kingston",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"746 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%timeit dm_func_sql(gl_grid, gb_grid)"
]
},
{
"cell_type": "code",
"execution_count": 145,
"id": "phantom-nancy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Speedup: 237.58x\n"
]
}
],
"source": [
"print(\"Speedup: {:2.2f}x\".format(746 / 3.14,))"
]
},
{
"cell_type": "markdown",
"id": "compound-domestic",
"metadata": {},
"source": [
"### Compare output\n",
"\n",
"These methods are pretty different, let's check they have the same output"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "unauthorized-garage",
"metadata": {},
"outputs": [],
"source": [
"v_hpx = dm_func_healpix(gl_grid, gb_grid)\n",
"v_sql = dm_func_frbpoppy(gl_grid, gb_grid)"
]
},
{
"cell_type": "code",
"execution_count": 139,
"id": "discrete-plymouth",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7eff1469c150>"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x432 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import pylab as plt\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plt.subplot(1,2,1)\n",
"plt.imshow(v_sql.reshape((180, 360)))\n",
"plt.colorbar(orientation='horizontal')\n",
"plt.subplot(1,2,2)\n",
"plt.imshow(v_hpx.reshape((180, 360)))\n",
"plt.colorbar(orientation='horizontal')"
]
},
{
"cell_type": "markdown",
"id": "remarkable-exemption",
"metadata": {},
"source": [
"They look the same, but the two are actually different enough for concern:"
]
},
{
"cell_type": "code",
"execution_count": 148,
"id": "loaded-scratch",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7eff14d4e7d0>"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(v_sql.reshape((180, 360)) - v_hpx.reshape((180, 360)))\n",
"plt.colorbar(orientation='horizontal')"
]
},
{
"cell_type": "markdown",
"id": "settled-inspection",
"metadata": {},
"source": [
"## Speeding up `gen_dist()`\n",
"\n",
"Approach here is to use an interpolation function instead of sqlite"
]
},
{
"cell_type": "code",
"execution_count": 218,
"id": "municipal-parcel",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>z</th>\n",
" <th>dist</th>\n",
" <th>vol</th>\n",
" <th>dvol</th>\n",
" <th>cdf_sfr</th>\n",
" <th>cdf_smd</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.00000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.00001</td>\n",
" <td>0.000044</td>\n",
" <td>3.630882e-13</td>\n",
" <td>3.630882e-13</td>\n",
" <td>2.816478e-17</td>\n",
" <td>8.261876e-16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.00002</td>\n",
" <td>0.000089</td>\n",
" <td>2.904685e-12</td>\n",
" <td>2.541597e-12</td>\n",
" <td>2.253219e-16</td>\n",
" <td>6.609439e-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.00003</td>\n",
" <td>0.000133</td>\n",
" <td>9.803245e-12</td>\n",
" <td>6.898559e-12</td>\n",
" <td>7.604724e-16</td>\n",
" <td>2.230666e-14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.00004</td>\n",
" <td>0.000177</td>\n",
" <td>2.323716e-11</td>\n",
" <td>1.343391e-11</td>\n",
" <td>1.802626e-15</td>\n",
" <td>5.287456e-14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>649996</th>\n",
" <td>6.49996</td>\n",
" <td>8.638431</td>\n",
" <td>2.699997e+03</td>\n",
" <td>3.621527e-03</td>\n",
" <td>9.999987e-01</td>\n",
" <td>9.999998e-01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>649997</th>\n",
" <td>6.49997</td>\n",
" <td>8.638435</td>\n",
" <td>2.700000e+03</td>\n",
" <td>3.621523e-03</td>\n",
" <td>9.999991e-01</td>\n",
" <td>9.999999e-01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>649998</th>\n",
" <td>6.49998</td>\n",
" <td>8.638439</td>\n",
" <td>2.700004e+03</td>\n",
" <td>3.621519e-03</td>\n",
" <td>9.999994e-01</td>\n",
" <td>9.999999e-01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>649999</th>\n",
" <td>6.49999</td>\n",
" <td>8.638443</td>\n",
" <td>2.700008e+03</td>\n",
" <td>3.621515e-03</td>\n",
" <td>9.999997e-01</td>\n",
" <td>1.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>650000</th>\n",
" <td>6.50000</td>\n",
" <td>8.638447</td>\n",
" <td>2.700011e+03</td>\n",
" <td>3.621511e-03</td>\n",
" <td>1.000000e+00</td>\n",
" <td>1.000000e+00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>650001 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" z dist vol dvol cdf_sfr \\\n",
"0 0.00000 0.000000 0.000000e+00 0.000000e+00 0.000000e+00 \n",
"1 0.00001 0.000044 3.630882e-13 3.630882e-13 2.816478e-17 \n",
"2 0.00002 0.000089 2.904685e-12 2.541597e-12 2.253219e-16 \n",
"3 0.00003 0.000133 9.803245e-12 6.898559e-12 7.604724e-16 \n",
"4 0.00004 0.000177 2.323716e-11 1.343391e-11 1.802626e-15 \n",
"... ... ... ... ... ... \n",
"649996 6.49996 8.638431 2.699997e+03 3.621527e-03 9.999987e-01 \n",
"649997 6.49997 8.638435 2.700000e+03 3.621523e-03 9.999991e-01 \n",
"649998 6.49998 8.638439 2.700004e+03 3.621519e-03 9.999994e-01 \n",
"649999 6.49999 8.638443 2.700008e+03 3.621515e-03 9.999997e-01 \n",
"650000 6.50000 8.638447 2.700011e+03 3.621511e-03 1.000000e+00 \n",
"\n",
" cdf_smd \n",
"0 0.000000e+00 \n",
"1 8.261876e-16 \n",
"2 6.609439e-15 \n",
"3 2.230666e-14 \n",
"4 5.287456e-14 \n",
"... ... \n",
"649996 9.999998e-01 \n",
"649997 9.999999e-01 \n",
"649998 9.999999e-01 \n",
"649999 1.000000e+00 \n",
"650000 1.000000e+00 \n",
"\n",
"[650001 rows x 6 columns]"
]
},
"execution_count": 218,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sqlite3, pandas as pd\n",
"\n",
"db = sqlite3.connect('../data/models/universe/h0-67d74-wm-0d3089-wv-0d6911.db')\n",
"df = pd.read_sql_query(\"SELECT * FROM DISTANCES\", db)\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "accurate-petroleum",
"metadata": {},
"source": [
"### Setup sqlite lookup method"
]
},
{
"cell_type": "code",
"execution_count": 197,
"id": "advanced-chuck",
"metadata": {},
"outputs": [],
"source": [
"dt = pc.DistanceTable(H_0=67.74, W_m=0.3089, W_v=0.6911).lookup\n",
"m = dt(z=np.array([1.0]))\n",
"dist_co_max = m[1]\n",
"vol_co_max = m[2]\n",
"cdf_sfr_max = m[-2]\n",
"cdf_smd_max = m[-1]\n",
"\n",
"def gen_dist_sql(vol_co):\n",
" z, dist_co, = dt(vol_co=vol_co)[:2]\n",
" return z, dist_co"
]
},
{
"cell_type": "markdown",
"id": "stretch-beatles",
"metadata": {},
"source": [
"### Setup interpolation method\n",
"\n",
"(sidenote: I think when using the interpolation method we probably don't need 650k points to be precalculated)"
]
},
{
"cell_type": "code",
"execution_count": 201,
"id": "pressed-baseline",
"metadata": {},
"outputs": [],
"source": [
"from scipy.interpolate import interp1d\n",
"\n",
"vol_co_to_z = interp1d(df['vol'], df['z'])\n",
"vol_co_to_dist = interp1d(df['vol'], df['dist'])\n",
"\n",
"def gen_dist_interp(vol_co):\n",
" z = vol_co_to_z(vol_co)\n",
" dist = vol_co_to_dist(vol_co)\n",
" return z, dist"
]
},
{
"cell_type": "markdown",
"id": "adult-testimony",
"metadata": {},
"source": [
"### Compare timings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "announced-nancy",
"metadata": {},
"outputs": [],
"source": [
"n_gen = 10000\n",
"vol_co = vol_co_max * np.random.random(n_gen)"
]
},
{
"cell_type": "code",
"execution_count": 202,
"id": "based-purpose",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.41 ms ± 23.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit gen_dist_interp(vol_co)"
]
},
{
"cell_type": "code",
"execution_count": 203,
"id": "superb-german",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.2 s ± 49.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%timeit gen_dist_sql(vol_co)"
]
},
{
"cell_type": "markdown",
"id": "regulated-consumer",
"metadata": {},
"source": [
"### Check results are the same"
]
},
{
"cell_type": "code",
"execution_count": 207,
"id": "growing-scott",
"metadata": {},
"outputs": [],
"source": [
"z_interp, dist_co_interp = gen_dist_interp(vol_co)\n",
"z_sql, dist_co_sql = gen_dist_sql(vol_co)"
]
},
{
"cell_type": "code",
"execution_count": 217,
"id": "labeled-supplier",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n"
]
}
],
"source": [
"print(np.allclose(z_interp, z_sql, atol=1e-5))\n",
"print(np.allclose(dist_co_interp, dist_co_sql, atol=1e-4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}