Show how to speed up scipy.integrate.odeint simply by decorating the right-hand side with numba's jit function
{
"metadata": {
"name": "",
"signature": "sha256:127c3bfc204cb24a2df9d516bc2b1a098979cb600df872957d9d75a619ca2b7a"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "heading",
"level": 1,
"metadata": {},
"source": [
"Speeding up `scipy`'s `odeint` with `numba`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the right-hand side of an ODE is slow to evaluate, integrating it with `scipy`'s `odeint` will be slow too. We can speed up that right-hand side with the [`numba`](http://numba.pydata.org/) package, which compiles Python functions to machine code via [LLVM](http://llvm.org/), so they can run much faster than ordinary Python. As we will see, even a very simple ODE can be sped up by a factor of roughly two.\n",
"\n",
"<br/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To install `numba`, I recommend installing [anaconda](http://continuum.io/downloads) and then running\n",
"```bash\n",
"conda install numpy scipy numba\n",
"```\n",
"Once this finishes, you should be able to run `python -c 'import numba'` without errors (in fact, it should print nothing at all)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a test case, we'll use the basic example [from the scipy tutorial](http://docs.scipy.org/doc/scipy/reference/tutorial/integrate.html#ordinary-differential-equations-odeint), which integrates Airy's equation\n",
"\\begin{equation*}\n",
" \\frac{d^{2}w}{dz^{2}}-zw(z)=0.\n",
"\\end{equation*}\n",
"Here's the basic Python code:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from numpy import arange\n",
"from scipy.integrate import odeint\n",
"from scipy.special import gamma, airy\n",
"\n",
"# Initial conditions for y = [w', w]: Ai'(0) and Ai(0)\n",
"y1_0 = 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)\n",
"y0_0 = -1.0 / 3**(1.0/3.0) / gamma(1.0/3.0)\n",
"y0 = [y0_0, y1_0]\n",
"t = arange(0, 4.0, 0.01)\n",
"\n",
"# dy/dt for the first-order system: d(w')/dz = z*w, dw/dz = w'\n",
"def RHS(y, t):\n",
"    return [t*y[1], y[0]]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
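{
"cell_type": "markdown",
"metadata": {},
"source": [
"`odeint` only handles first-order systems, so the code above rewrites Airy's equation as one. With $t = z$ and $y = [w', w]$ (the same ordering the scipy tutorial uses for `y0`), the system is\n",
"\\begin{equation*}\n",
" \\frac{dy_0}{dt} = t\\, y_1, \\qquad \\frac{dy_1}{dt} = y_0,\n",
"\\end{equation*}\n",
"which is exactly what `RHS` returns. Since the initial values are $\\mathrm{Ai}'(0)$ and $\\mathrm{Ai}(0)$, the solution should track the Airy function $\\mathrm{Ai}$. Here's a quick sanity check against `scipy.special.airy`, which the tutorial also imports; `RHS_list` is just an illustrative name, and the size of the printed errors depends on `odeint`'s default tolerances:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"from scipy.integrate import odeint\n",
"from scipy.special import gamma, airy\n",
"\n",
"# Same setup as above: y = [w', w], starting from Ai'(0) and Ai(0)\n",
"y0 = [-1.0 / 3**(1.0/3.0) / gamma(1.0/3.0), 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)]\n",
"t = np.arange(0, 4.0, 0.01)\n",
"\n",
"def RHS_list(y, t):\n",
"    return [t*y[1], y[0]]\n",
"\n",
"sol = odeint(RHS_list, y0, t)\n",
"Ai, Aip, Bi, Bip = airy(t)  # airy returns Ai, Ai', Bi, Bi'\n",
"\n",
"# Largest absolute deviation from the exact solution over the whole grid\n",
"print(np.max(np.abs(sol[:, 1] - Ai)))   # w  vs. Ai\n",
"print(np.max(np.abs(sol[:, 0] - Aip)))  # w' vs. Ai'"
],
"language": "python",
"metadata": {},
"outputs": []
},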
{
"cell_type": "code",
"collapsed": false,
"input": [
"%timeit odeint(RHS, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 684 \u00b5s per loop\n"
]
}
],
"prompt_number": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's not bad, but this is a very simple test case. In general, `odeint` calls `RHS` many times during a single integration, so `RHS` is the function we want to speed up. First, we'll have to rewrite it slightly for `numba`. For a fair comparison later, we time the same rewrite in plain Python:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Overwrite the components of y in place instead of building a new list\n",
"def RHS(y, t):\n",
"    y[0], y[1] = t*y[1], y[0]\n",
"    return y\n",
"\n",
"%timeit odeint(RHS, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 692 \u00b5s per loop\n"
]
}
],
"prompt_number": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rewrite made essentially no difference to the timing. We needed it because, at the time of writing, `numba` couldn't create a new array (at least not without being very slow about it). Fortunately, the first argument that `odeint` passes to `RHS` is an array that gets thrown away anyway, so we can just overwrite its values and return it. We can't simply assign a new value to `y` itself, because that would mean creating a new array; instead we set each component, which is why we use the tuple-assignment syntax above. (Python evaluates the whole right-hand side before overwriting either component, so `y[0]` and `y[1]` are updated simultaneously.)\n",
"\n",
"Now, to benefit from `numba`, only two simple steps are needed:\n",
"\n",
" 1. Put `from numba import jit` somewhere above the function definition\n",
" 2. Put `@jit` on the line immediately before the function definition"
]
},
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from numba import jit\n", | |
"\n", | |
"@jit\n", | |
"def RHS(y, t):\n", | |
" y[0],y[1] = t*y[1],y[0]\n", | |
" return y\n", | |
"\n", | |
"%timeit odeint(RHS, y0, t)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"1000 loops, best of 3: 387 \u00b5s per loop\n" | |
] | |
} | |
], | |
"prompt_number": 4 | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, the `numba`-compiled version is about 1.8 times as fast as the equivalent pure-Python version (387 \u00b5s versus 692 \u00b5s per integration), even though the function is almost trivial. For more complicated right-hand sides you can often expect larger improvements. (I've seen speed-ups by factors of 1000!)\n",
"\n",
"The only catch is that `numba` sometimes falls back to \"object mode\", where it emits code that is essentially the same as Python, so it isn't any faster. You can check whether this happened by running\n",
"```python\n",
"RHS.inspect_types()\n",
"```\n",
"If you see a lot of lines ending with `:: pyobject`, then `numba` isn't helping you, and you'll have to figure out what's going wrong. (Recent `numba` releases compile `@jit` functions in nopython mode by default and raise an error instead of falling back; with older releases, use `@jit(nopython=True)` or `@njit` to get that behavior.)"
]
}
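,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One caveat about the in-place trick: ODEPACK, the Fortran library that `odeint` wraps, documents that the right-hand-side routine should not alter `y`, so overwriting `y` relies on how `odeint` happens to handle that array. Since `numba` can allocate arrays in nopython mode in current versions, a safer variant returns a fresh array instead. The cell below is a sketch of that approach, assuming a `numba` release that supports `np.empty_like` in nopython mode; `RHS_safe` and `RHS_python` are illustrative names. Using `@njit` also guarantees the function is never compiled in object mode: if `numba` can't compile it in nopython mode, the first call raises an error."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"from numba import njit\n",
"from scipy.integrate import odeint\n",
"from scipy.special import gamma\n",
"\n",
"@njit  # nopython mode only: no silent fallback to object mode\n",
"def RHS_safe(y, t):\n",
"    dydt = np.empty_like(y)  # new output array; y itself is left untouched\n",
"    dydt[0] = t*y[1]\n",
"    dydt[1] = y[0]\n",
"    return dydt\n",
"\n",
"def RHS_python(y, t):\n",
"    return [t*y[1], y[0]]\n",
"\n",
"y0 = np.array([-1.0 / 3**(1.0/3.0) / gamma(1.0/3.0), 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)])\n",
"t = np.arange(0, 4.0, 0.01)\n",
"\n",
"# Both right-hand sides do the same arithmetic, so the solutions should agree\n",
"print(np.max(np.abs(odeint(RHS_safe, y0, t) - odeint(RHS_python, y0, t))))\n",
"\n",
"%timeit odeint(RHS_safe, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": []
}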
],
"metadata": {}
}
]
}