Skip to content

Instantly share code, notes, and snippets.

@samuelstjean
Last active April 8, 2017 07:36
Show Gist options
  • Save samuelstjean/79958a07900bf417e5a610bfff9929fd to your computer and use it in GitHub Desktop.
Save samuelstjean/79958a07900bf417e5a610bfff9929fd to your computer and use it in GitHub Desktop.
benchmark gsl hyp1f1 vs scipy hyp1f1
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.19.0\n"
]
}
],
"source": [
"import re\n",
"import numpy as np\n",
"\n",
"from itertools import repeat, product, chain\n",
"\n",
"from scipy.special import hyp1f1 as sci_1f1\n",
"import scipy\n",
"\n",
"# Compare versions numerically: a plain string comparison is lexicographic, so e.g.\n",
"# '0.9' < '0.17' is False and an old scipy would slip through without a warning.\n",
"scipy_major_minor = tuple(int(part) for part in re.match(r'(\\d+)\\.(\\d+)', scipy.__version__).groups())\n",
"if scipy_major_minor < (0, 17):\n",
"    print('You have scipy {}, but versions before 0.17 are known to have issues with hyp1f1'.format(scipy.__version__))\n",
"\n",
"from mpmath import hyp1f1 as mp_1f1\n",
"\n",
"%load_ext Cython\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"print(scipy.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## You need CythonGSL (and the GSL C library) installed for this to run"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"%%cython -lgsl -lgslcblas\n",
"\n",
"cimport cython\n",
"cimport numpy\n",
"\n",
"from cython_gsl cimport gsl_sf_hyperg_1F1, gsl_set_error_handler_off\n",
"\n",
"# GSL's default error handler prints a message and calls abort() when a special\n",
"# function overflows or loses precision, which would kill the kernel in the middle\n",
"# of the SNR sweep below. Turn it off so gsl_sf_hyperg_1F1 just returns its\n",
"# best estimate (possibly nan/inf) and the comparison against scipy can flag it.\n",
"gsl_set_error_handler_off()\n",
"\n",
"@cython.boundscheck(False)\n",
"@cython.wraparound(False)\n",
"@cython.cdivision(True)\n",
"def gsl_1f1(double a, double b, double x):\n",
"    \"\"\"Evaluate the confluent hypergeometric function 1F1(a; b; x) with GSL.\n",
"\n",
"    http://en.wikipedia.org/wiki/Confluent_hypergeometric_function\n",
"\n",
"    Thin wrapper around gsl_sf_hyperg_1F1; the arguments are converted to C doubles\n",
"    and the result is returned as a Python float.\n",
"    \"\"\"\n",
"    return gsl_sf_hyperg_1F1(a, b, x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Big for loop checker\n",
"\n",
"We now check, for a range of SNR values, whether there is a discrepancy between gsl hyp1f1 and scipy hyp1f1.\n",
"\n",
"If the difference is larger than `tol`, both results are also compared against mpmath to find out which one is less precise.\n",
"\n",
"Unfortunately, so far the culprit always seems to be scipy. This assumes mpmath gets things right: it works in arbitrary precision, so by design it should not suffer from these precision issues."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"+ scipy definitely does **not** work before 0.17 (it overflows early on and carries a lot of numerical error), and it still has precision issues in 0.19"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/samuel/anaconda3/lib/python3.5/site-packages/ipykernel/__main__.py:10: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.\n"
]
},
{
"ename": "AssertionError",
"evalue": "gsl = 63.7982891195326 scipy = 63.689692119208395 diff = 0.1085970003242025\n diff_mpmath = 3.44653732636813e-12, 0.108597000327649\n (a,b,c) = (-0.5, 61, -247207.56154023242)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-84a60b747be6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 18\u001b[0m assert(np.abs(resgsl-resscipy) < tol), \"gsl = {} scipy = {} diff = {}\\n diff_mpmath = {}, {}\\n (a,b,c) = ({}, {}, {})\" .format(resgsl, resscipy, np.abs(resgsl-resscipy), \n\u001b[1;32m 19\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresgsl\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmp_1f1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresscipy\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmp_1f1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m a,b,c)\n\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m: gsl = 63.7982891195326 scipy = 63.689692119208395 diff = 0.1085970003242025\n diff_mpmath = 3.44653732636813e-12, 0.108597000327649\n (a,b,c) = (-0.5, 61, -247207.56154023242)"
]
}
],
"source": [
"A = np.array([-0.5]) # always 0.5 or -0.5\n",
"B = np.arange(1, 64) # 0.5 * number of degrees of freedom (number of coils N)\n",
"C = lambda snr: -(snr**2)/4 # snr dependent\n",
"tol = 1e-1 # 1e-3 makes some values fail at N > 32, problems always arise at 1e-6 for N = 12 \n",
"# and at very high SNR even gsl starts to 'only' reach 1e-12 precision.\n",
"\n",
"start = 0\n",
"end = 1000 # SNR = signal / std, so it is crazy high in the b0 near csf/partial volume, which is a problematic region for hyp1f1\n",
"spacing = 0.1\n",
"# np.linspace needs an integer number of samples (a float `num` is deprecated and an error\n",
"# in recent numpy); the +1 includes both endpoints so the step is exactly `spacing`.\n",
"SNRs = np.linspace(start, end, num=int(round((end - start) / spacing)) + 1)\n",
"\n",
"for a, b, snr in chain(product(A, B, SNRs), product(-A, B+1, SNRs)):\n",
"\n",
"    c = C(snr)\n",
"    resgsl = gsl_1f1(a, b, c)\n",
"    resscipy = sci_1f1(a, b, c)\n",
"    diff = np.abs(resgsl - resscipy)\n",
"\n",
"    # `not diff < tol` (rather than `diff >= tol`) also stops on nan results.\n",
"    if not diff < tol:\n",
"        # Only pay for the slow arbitrary-precision reference when there is a mismatch.\n",
"        reference = mp_1f1(float(a), float(b), float(c))\n",
"        raise AssertionError(\"gsl = {} scipy = {} diff = {}\\n diff_mpmath = {}, {}\\n (a,b,c) = ({}, {}, {})\"\n",
"                             .format(resgsl, resscipy, diff,\n",
"                                     np.abs(resgsl - reference), np.abs(resscipy - reference),\n",
"                                     a, b, c))"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## To check some particular values"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n",
"5.28396681964693e-8\n",
"The slowest run took 10.73 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 513 ns per loop\n",
"The slowest run took 8.32 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 4.07 µs per loop\n",
"1000 loops, best of 3: 431 µs per loop\n"
]
}
],
"source": [
"# Absolute error of gsl and scipy with respect to the arbitrary-precision mpmath value,\n",
"# followed by the cost of a single call for each implementation.\n",
"a, b, c = -0.5, 15, -77.54335667852396\n",
"reference = mp_1f1(a, b, c)\n",
"\n",
"for implementation in (gsl_1f1, sci_1f1):\n",
"    print(np.abs(implementation(a, b, c) - reference))\n",
"\n",
"%timeit gsl_1f1(a, b, c)\n",
"%timeit sci_1f1(a, b, c)\n",
"%timeit mp_1f1(a, b, c)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment