Last active July 6, 2021 18:32
Show how to speed up scipy.integrate.odeint simply by decorating the right-hand side with numba's jit function
"metadata": {
"name": "",
"signature": "sha256:127c3bfc204cb24a2df9d516bc2b1a098979cb600df872957d9d75a619ca2b7a"
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
"cells": [
"cell_type": "heading",
"level": 1,
"metadata": {},
"source": [
"Speeding up `scipy`'s `odeint` with `numba`"
"cell_type": "markdown",
"metadata": {},
"source": [
"When the right-hand side of an ODE integration is slow, `scipy`'s `odeint` integration of it will be slow. We can speed up that right-hand side with the [`numba`]( package, which compiles python code into machine code via [LLVM]( -- which means it's super fast. As we will see, even a very simple ODE can be sped up by a factor of roughly two.\n",
"cell_type": "markdown",
"metadata": {},
"source": [
"To install `numba`, I recommend just installing [anaconda](, and then doing\n",
"conda install numpy scipy numba\n",
"Once this finishes, you should be able to run `python -c 'import numba'` without an error (or any output, actually)."
"cell_type": "markdown",
"metadata": {},
"source": [
"As an example, we'll use the basic example [from the scipy tutorial]( This integrates the ODE\n",
" \\frac{d^{2}w}{dz^{2}}-zw(z)=0.\n",
"Here's the basic python code:"
"cell_type": "code",
"collapsed": false,
"input": [
"from scipy.integrate import odeint\n",
"from scipy.special import gamma, airy\n",
"y1_0 = 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)\n",
"y0_0 = -1.0 / 3**(1.0/3.0) / gamma(1.0/3.0)\n",
"y0 = [y0_0, y1_0]\n",
"t = arange(0, 4.0, 0.01)\n",
"def RHS(y, t):\n",
" return [t*y[1],y[0]]"
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
"cell_type": "code",
"collapsed": false,
"input": [
"%timeit odeint(RHS, y0, t)"
"language": "python",
"metadata": {},
"outputs": [
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 684 \u00b5s per loop\n"
"prompt_number": 2
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad speed, though this is just a simple test. More generally, we want to speed up that `RHS` function. First, we'll have to rewrite the function slightly for numba. And for a fair comparison later, we do the same rewrite for python:"
"cell_type": "code",
"collapsed": false,
"input": [
"def RHS(y, t):\n",
" y[0],y[1] = t*y[1],y[0]\n",
" return y\n",
"%timeit odeint(RHS, y0, t)"
"language": "python",
"metadata": {},
"outputs": [
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 692 \u00b5s per loop\n"
"prompt_number": 3
"cell_type": "markdown",
"metadata": {},
"source": [
"The rewrite didn't really make any difference. We had to do this because at the moment, numba can't create a new array (without being really slow about it). Fortunately, the first argument to `odeint` is an array that gets thrown away anyway, so we can just replace the values in that array and return it. Of course, we can't set `y` itself to something, because that would be defining a new array; we have to set each component, which is why we use the weird syntax above.\n",
"Now, to benefit from `numba`, just two simple steps are needed:\n",
" 1. Put `from numba import jit` somewhere above the function definition\n",
" 2. Put `@jit` on the line immediately before the function definition"
"cell_type": "code",
"collapsed": false,
"input": [
"from numba import jit\n",
"def RHS(y, t):\n",
" y[0],y[1] = t*y[1],y[0]\n",
" return y\n",
"%timeit odeint(RHS, y0, t)"
"language": "python",
"metadata": {},
"outputs": [
"output_type": "stream",
"stream": "stdout",
"text": [
"1000 loops, best of 3: 387 \u00b5s per loop\n"
"prompt_number": 4
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, the numba-wrapped version is in the neighborhood of twice as fast, even though the function is almost trivial. You can expect larger improvements for more complicated functions. (I've seen speed-ups of factors of 1000!)\n",
"The only catch is that `numba` some times falls into \"object mode\", where it emits code that is essentially the same as python, so it isn't any faster. You can check if this happens by running\n",
"If you see a lot of lines ending with `:: pyobject`, then `numba` isn't helping you, and you have to figure out what's going wrong."
"metadata": {}

ekpuz commented Oct 30, 2014

I'm relatively new to Python and Numba, but when I tried to run this example and compare the results of the integration, I got different values from the modified code. I wonder whether odeint relies on the y0 passed to it not being changed by the evaluation of the RHS. The jit speedup is remarkable, but unless I have made a mistake, I don't think the resulting answer is correct.
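
A concern like this can be checked against the exact solution, since the ODE in the notebook is solved by the Airy function Ai. The sketch below is a minimal version of such a check; `rhs_copy` and `max_airy_error` are hypothetical names introduced here, and only the non-mutating right-hand side is tested.

```python
import numpy as np
from scipy.integrate import odeint
from scipy.special import airy, gamma

# Initial conditions from the notebook: y = [Ai'(t), Ai(t)] at t = 0.
y1_0 = 1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)
y0_0 = -1.0 / 3**(1.0/3.0) / gamma(1.0/3.0)
y0 = [y0_0, y1_0]
t = np.arange(0, 4.0, 0.01)


def rhs_copy(y, t):
    # Returns a new list and leaves y untouched.
    return [t * y[1], y[0]]


def max_airy_error(rhs):
    # Hypothetical helper: largest absolute deviation of y[1] from Ai(t).
    sol = odeint(rhs, y0, t)
    return np.max(np.abs(sol[:, 1] - airy(t)[0]))


err = max_airy_error(rhs_copy)
print(err)
```

Passing any other candidate right-hand side (for example a jitted or in-place version) to `max_airy_error` gives a direct comparison with this baseline.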

KaBrrrp commented Mar 10, 2015

Does this also work with integrate.ode (not integrate.odeint) ? I use ode.set_f_params to set some parameters for calculating the RHS, but then it fails with some TypeError complaining that the required argument is not found. Any ideas?

Update: also fails without setting extra parameters. I assume set_initial_value is enough to screw up numba functionality

IMHO, a much simpler solution that "always works" is to use the C-based integrators in PyDSTool or on github. It includes automatic code generation to turn your string-based declarations of ODEs into actual C code (no python callbacks) and executes much faster. You just need to have gcc/gfortran and SWIG installed, which is easy with the conda package manager (among other solutions).

@KaBrrrp, did you find a solution for integrate.ode?

bobzwik commented Aug 4, 2019

Hello @moble, I've tried to implement numba with my ode function, and I'm having trouble because I am passing a dictionary and a custom object as arguments in the odeint function. I've never used numba before, but I tried your code and some simple for loops to understand the gist of it, but I've hit a wall.

Link to my question on stackoverflow

fccoelho commented Dec 9, 2020

> IMHO, a much simpler solution that "always works" is to use the C-based integrators in PyDSTool or on github. It includes automatic code generation to turn your string-based declarations of ODEs into actual C code (no python callbacks) and executes much faster. You just need to have gcc/gfortran and SWIG installed, which is easy with the conda package manager (among other solutions).

Yes, but to implement the same model in PyDSTool you need to write much more code. "Simple is better than complex"

this gist is old (2014) but still relevant, imo.

I am currently not able to use ode; it fails with that same error: "TypeError: not enough arguments: expected 2, got 1"

Note there is a suggestion to bypass the arg error by wrapping the jitted function in a plain Python function.

I was able to use jit with odeint, as mentioned here, but I had a problem with the following statement:

"Fortunately, the first argument to odeint is an array that gets thrown away anyway, so we can just replace the values in that array and return it."

In that case I got different (absurd) results when overwriting the original y memory.
Instead I preallocated a scratch variable outside the function to hold the dy:

scratch_dN = np.zeros((5,)) #scratch memory allocated externally to dNi, reused between calls
sol = odeint(dNi, Ni0, t, args=(scratch_dN,) )

where dNi had been defined as:

def dNi(N,t,scratch_dN): #!!
    vector_dNi = scratch_dN # alias: reuse the preallocated memory
    vector_dNi[0] = ...
    ... # up to vector_dNi[4]
    return vector_dNi
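
For reference, the same scratch-buffer idea applied to the Airy example from the notebook might look like the sketch below. The name `rhs_scratch` is hypothetical, and numba is left out so that the example needs only numpy and scipy; adding `@jit` above the function would be the step the notebook describes. It relies on `odeint` copying the returned values before the next call, which matches the behaviour reported in this comment.

```python
import numpy as np
from scipy.integrate import odeint
from scipy.special import airy, gamma

# Same initial conditions as the notebook: [Ai'(0), Ai(0)].
y0 = [-1.0 / 3**(1.0/3.0) / gamma(1.0/3.0),
      1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)]
t = np.arange(0, 4.0, 0.01)


def rhs_scratch(y, t, scratch):
    # Write the derivative into preallocated memory; y is only read.
    scratch[0] = t * y[1]
    scratch[1] = y[0]
    return scratch


scratch = np.zeros(2)  # allocated once, reused on every call
sol = odeint(rhs_scratch, y0, t, args=(scratch,))

# Largest absolute deviation of the second component from Ai(t).
err = np.max(np.abs(sol[:, 1] - airy(t)[0]))
print(err)
```

Because the state `y` is never written to, the integrator's own copy of the solution stays intact while the allocation per call is still avoided.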

moble commented Mar 24, 2021

@iurisegtovich Yeah, I think something has changed internally with scipy. Also note that odeint and ode are actually considered the "Old API" now; the new one uses solve_ivp and friends, but these examples are about 1,000 times slower with the new API!!! (I can imagine it's all in the overhead.)

Anyway, your method is a good way to go. Numba also supports jitclass now, so you could also pass a more complicated object as one of the args, with all sorts of fancy capabilities.

But maybe more importantly, python itself has sped up significantly, so that even using the naive approach in this notebook gives nearly the same speed as when using numba. Obviously, really complicated functions will still benefit from numba, but in this example numba is actually a little bit slower in my tests.
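
For anyone who wants to try the newer interface mentioned above, a minimal sketch of the same Airy problem with `solve_ivp` follows. The name `rhs_ivp` and the tolerance values are choices made for this illustration, not something taken from the notebook. The main thing to watch is the argument order: `solve_ivp` calls `fun(t, y)`, whereas `odeint` calls `func(y, t)` by default.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.special import airy, gamma

# Same initial conditions as the notebook: [Ai'(0), Ai(0)].
y0 = [-1.0 / 3**(1.0/3.0) / gamma(1.0/3.0),
      1.0 / 3**(2.0/3.0) / gamma(2.0/3.0)]
t = np.arange(0, 4.0, 0.01)


def rhs_ivp(t, y):
    # solve_ivp passes time first and state second.
    return [t * y[1], y[0]]


# method="LSODA" selects the same family of solver that odeint wraps.
sol = solve_ivp(rhs_ivp, (t[0], t[-1]), y0, method="LSODA",
                t_eval=t, rtol=1e-8, atol=1e-10)

# sol.y has shape (n_states, n_times); row 1 should follow Ai(t).
err = np.max(np.abs(sol.y[1] - airy(t)[0]))
print(sol.success, err)
```

Timing this with `%timeit` next to the `odeint` call from the notebook is a quick way to see the overhead difference on a given machine.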

Nicholaswogan commented Jul 6, 2021

I wrote a wrapper to LSODA which has no overhead. During an ODE solve, the python interpreter is never used, so everything is fast for small problems:

from NumbaLSODA import lsoda_sig, lsoda
import numba as nb
import numpy as np

@nb.cfunc(lsoda_sig)  # compile to a C callback; .address below needs this
def RHS_nb(t, y, dy, p):
    dy[0],dy[1] = t*y[1],y[0]

funcptr = RHS_nb.address

def test():
    sol, success = lsoda(funcptr, y0_, t)
y0_ = np.array(y0)
%timeit test()

result is

26.2 µs ± 342 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
