
@iurisegtovich
Forked from moble/NumbaODEExample.ipynb
Created December 10, 2020 00:42
Show how to speed up scipy.integrate.odeint simply by decorating the right-hand side with numba's jit function
{
"metadata": {
"name": "",
"signature": "sha256:127c3bfc204cb24a2df9d516bc2b1a098979cb600df872957d9d75a619ca2b7a"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "heading",
"level": 1,
"metadata": {},
"source": [
"Speeding up `scipy`'s `odeint` with `numba`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the right-hand side of an ODE is slow to evaluate, `scipy`'s `odeint` will be slow to integrate it, since it calls the right-hand side many times. We can speed up that function with the [`numba`](http://numba.pydata.org/) package, which compiles Python code to machine code via [LLVM](http://llvm.org/) and can make numerical code much faster. As we will see, even a very simple ODE can be sped up by a factor of roughly two.\n",
"\n",
"<br/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To install `numba`, I recommend installing [Anaconda](http://continuum.io/downloads) and then running\n",
"```bash\n",
"conda install numpy scipy numba\n",
"```\n",
"Once this finishes, you should be able to run `python -c 'import numba'` without an error (or any output, actually)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use the basic example [from the scipy tutorial](http://docs.scipy.org/doc/scipy/reference/tutorial/integrate.html#ordinary-differential-equations-odeint), which integrates the Airy equation\n",
"\\begin{equation*}\n",
" \\frac{d^{2}w}{dz^{2}}-zw(z)=0.\n",
"\\end{equation*}\n",
"Here's the basic Python code:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
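"import numpy as np\n",
"from scipy.integrate import odeint\n",
"from scipy.special import gamma, airy\n",
"\n",
"# Initial conditions from the scipy tutorial: y = [w', w] at z = 0,\n",
"# i.e. y0_0 = Ai'(0) and y1_0 = Ai(0)\n",
"y1_0 = 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)\n",
"y0_0 = -1.0 / 3**(1.0/3.0) / gamma(1.0/3.0)\n",
"y0 = [y0_0, y1_0]\n",
"t = np.arange(0, 4.0, 0.01)\n",
"\n",
"def RHS(y, t):\n",
"    # dy[0]/dt = t*w = t*y[1],  dy[1]/dt = w' = y[0]\n",
"    return [t*y[1], y[0]]"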
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
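{
"cell_type": "markdown",
"metadata": {},
"source": [
"`odeint` only integrates first-order systems, so the code above (following the scipy tutorial) stores the state as $y = (y_0, y_1) = (w', w)$ and uses `t` in place of $z$. As a short worked step, this is what `RHS` returns:\n",
"\\begin{align*}\n",
" \\frac{dy_0}{dz} &= \\frac{d^{2}w}{dz^{2}} = z\\,w = z\\,y_1, \\\\\n",
" \\frac{dy_1}{dz} &= \\frac{dw}{dz} = y_0,\n",
"\\end{align*}\n",
"with initial values $y_0(0) = \\mathrm{Ai}'(0) = -1/\\left(3^{1/3}\\,\\Gamma(1/3)\\right)$ and $y_1(0) = \\mathrm{Ai}(0) = 1/\\left(3^{2/3}\\,\\Gamma(2/3)\\right)$, which are `y0_0` and `y1_0` in the code. With these initial values the exact solution is $w(z) = \\mathrm{Ai}(z)$, the Airy function that `scipy.special.airy` returns as its first output."
]
},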
{
"cell_type": "code",
"collapsed": false,
"input": [
"%timeit odeint(RHS, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 684 \u00b5s per loop\n"
]
}
],
"prompt_number": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's reasonably fast, though this is just a simple test. More generally, it is the `RHS` function, which `odeint` calls many times, that we want to speed up. First, we have to rewrite the function slightly for `numba`. For a fair comparison later, we time the same rewrite in plain Python:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def RHS(y, t):\n",
"    # Overwrite the components of y in place instead of building a new list\n",
"    y[0], y[1] = t*y[1], y[0]\n",
"    return y\n",
"\n",
"%timeit odeint(RHS, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 692 \u00b5s per loop\n"
]
}
],
"prompt_number": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
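"The rewrite made essentially no difference. We had to do it because, at the time of writing, `numba` couldn't create a new array (at least not without being really slow about it). Fortunately, the array `y` that `odeint` passes as the first argument to `RHS` is treated here as scratch space that gets thrown away anyway, so we can just overwrite its values and return it. This relies on `odeint` not reusing `y` after the call, which its documentation does not promise, so it is worth checking the output against a known solution. Of course, we can't rebind `y` itself (as in `y = [...]`), because that would create a new object; we have to set each component, which is why we use the tuple-assignment syntax above.\n",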
"\n",
"Now, to benefit from `numba`, just two simple steps are needed:\n",
"\n",
" 1. Put `from numba import jit` somewhere above the function definition\n",
" 2. Put `@jit` on the line immediately before the function definition"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from numba import jit\n",
"\n",
"@jit\n",
"def RHS(y, t):\n",
"    # Same in-place update as above, now compiled by numba\n",
"    y[0], y[1] = t*y[1], y[0]\n",
"    return y\n",
"\n",
"%timeit odeint(RHS, y0, t)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 387 \u00b5s per loop\n"
]
}
],
"prompt_number": 4
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, the `numba`-compiled version is almost twice as fast (387 \u00b5s per loop vs. 692 \u00b5s for the same rewritten function in plain Python), even though the function is almost trivial. More complicated right-hand sides can see much larger improvements. (I've seen speed-ups by factors of 1000!)\n",
"\n",
"The only catch is that `numba` sometimes falls back to \"object mode\", where it emits code that works on Python objects much as the interpreter does, so it isn't any faster. (Since `numba` 0.59, `@jit` defaults to nopython mode and raises an error instead of silently falling back.) You can check whether this happens by running\n",
"```python\n",
"RHS.inspect_types()\n",
"```\n",
"after `RHS` has been called at least once. If you see a lot of lines ending with `:: pyobject`, then `numba` isn't helping you, and you have to figure out what's going wrong."
]
}
],
"metadata": {}
}
]
}
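
The notebook above only times the integrations and never checks the answers. Because the exact solution here is w(z) = Ai(z), a correctness check is easy. The following is a minimal sketch that is not part of the original notebook, and it makes a few assumptions. `rhs_new_array` and `rhs_in_place` are hypothetical names. The `rtol`/`atol` values are arbitrary choices, tighter than `odeint`'s defaults. The sketch uses `numba.njit` to allocate a fresh output array with `np.empty`, which nopython mode supports in current `numba` releases, and falls back to plain Python if `numba` is not installed. It also runs the notebook's in-place `RHS` (uncompiled) through the same check, without assuming what that result will be.

```python
# Sketch (not from the original notebook): check odeint against the exact
# Airy solution, using a right-hand side that returns a fresh array.
import numpy as np
from scipy.integrate import odeint
from scipy.special import airy, gamma

try:
    from numba import njit  # nopython-mode compilation
except ImportError:
    # Fallback so the sketch still runs (uncompiled) without numba installed
    def njit(func):
        return func


@njit
def rhs_new_array(y, t):
    # Hypothetical name. State is y = [w', w]; write into a new array
    out = np.empty(2)
    out[0] = t * y[1]  # d(w')/dt = t * w
    out[1] = y[0]      # dw/dt = w'
    return out


def rhs_in_place(y, t):
    # The notebook's in-place version (plain Python), for comparison
    y[0], y[1] = t * y[1], y[0]
    return y


y0 = [-1.0 / 3**(1.0/3.0) / gamma(1.0/3.0),  # Ai'(0)
      1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)]   # Ai(0)
t = np.arange(0, 4.0, 0.01)
ai = airy(t)[0]  # exact solution w(t) = Ai(t)

# Tighter tolerances than odeint's defaults (an arbitrary choice)
sol = odeint(rhs_new_array, y0, t, rtol=1e-10, atol=1e-12)
max_err = np.max(np.abs(sol[:, 1] - ai))
print("fresh array: max |w - Ai| =", max_err)

sol_in_place = odeint(rhs_in_place, y0, t, rtol=1e-10, atol=1e-12)
max_err_in_place = np.max(np.abs(sol_in_place[:, 1] - ai))
print("in place:    max |w - Ai| =", max_err_in_place)
```

If the in-place version reports a much larger error than the fresh-array version, then overwriting `y` is not safe with your version of `scipy`, and returning a new array is probably the better choice. `%timeit` can then be used exactly as in the notebook to compare the two versions' speed.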