Torchdiffeq vs DifferentialEquations.jl (/ DiffEqFlux.jl) Neural ODE Compatible Solver Benchmarks
Only non-stiff ODE solvers are tested, since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and in model-informed drug development (MIDD) studies (quantitative systems pharmacology), in order to capture performance in realistic scenarios.
Summary
Below are the timings relative to the fastest method (lower is better). For approximately 1 million ODEs or fewer, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl on every tested problem, and on several problems it was two or more orders of magnitude slower. Note, however, that the relative performance of torchdiffeq improves as the number of ODEs increases.
Additionally, torchdiffeq either exhibited slower gradient calculations or the gradient calculation diverged. For more on why the torchdiffeq calculation diverges, see this manuscript, along with many others detailing the stability of backsolve adjoint methods. Note that when `sensealg=BacksolveAdjoint()` is used in DifferentialEquations.jl on these problems it similarly diverges, indicating that the issue lies with that algorithm (which is why the DiffEq documentation does not recommend this algorithm on such problems).
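As a rough illustration of the stability issue, here is a toy sketch in plain Python. It is not code from torchdiffeq or DifferentialEquations.jl. Backsolve-style adjoints reconstruct the forward state by integrating the ODE backward in time. For a decaying system the reversed ODE grows, so a tiny error in the final state is amplified exponentially. The linear test problem u' = -λu, the values of `LAM` and `T`, the perturbation size, and the helper `rk4_solve` are all assumptions chosen for illustration; the benchmark problems may diverge for related but more complicated reasons.

```python
def rk4_solve(f, u0, t0, t1, n_steps):
    """Integrate u' = f(t, u) from t0 to t1 with fixed-step classical RK4.

    Hypothetical helper for this sketch; t1 < t0 integrates backward in time.
    """
    h = (t1 - t0) / n_steps
    t, u = t0, u0
    for _ in range(n_steps):
        k1 = f(t, u)
        k2 = f(t + h / 2, u + h * k1 / 2)
        k3 = f(t + h / 2, u + h * k2 / 2)
        k4 = f(t + h, u + h * k3)
        u += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t += h
    return u

LAM = 10.0  # decay rate (illustrative assumption)
T = 5.0     # final time (illustrative assumption)

def decay(t, u):
    # Stable forward in time, unstable when integrated in reverse.
    return -LAM * u

# Forward pass: u(0) = 1 decays to roughly exp(-50).
u_T = rk4_solve(decay, 1.0, 0.0, T, 500)

# Backward reconstruction from the exact forward result.
u0_clean = rk4_solve(decay, u_T, T, 0.0, 500)

# Backward reconstruction after a tiny error in the final state.
u0_perturbed = rk4_solve(decay, u_T + 1e-12, T, 0.0, 500)

print("reconstructed u(0), clean:    ", u0_clean)
print("reconstructed u(0), perturbed:", u0_perturbed)
```

In this sketch the unperturbed round trip recovers a value close to 1, while the perturbed one is off by many orders of magnitude, because the backward solve multiplies any error by roughly exp(LAM * T).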
Relative Performance of Forward Pass (Lower is Better)
Number of ODEs | 3 | 28 | 768 | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
---|---|---|---|---|---|---|---|---|
DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
DifferentialEquations.jl dopri5 | 1.0x | 1.6x | 2.8x | 2.7x | 3.0x | 3.0x | 3.9x | 2.8x |
torchdiffeq dopri5 | 4,900x | 190x | 840x | 220x | 82x | 31x | 24x | 17x |
Relative Performance of Gradient Pass (Lower is Better)
Number of Parameters | 3 | 4 | 256 | 1,024 |
---|---|---|---|---|
DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x |
torchdiffeq dopri5 | 12,000x | 1,200x | ---- | ---- |
`----` means the gradient calculation diverged. To be clear, torchdiffeq did not successfully compute the gradient on any of the PDE experiments. All returned an error due to `dt` underflow, leading to the experiment being halted.
Benchmark Details
Lorenz Equation (3 ODEs)
Absolute Timings
Forward Pass
- DifferentialEquations.jl: 1.742 ms
- SciPy+Numba: 30.8 ms
- SciPy: 50.2 ms
- SciPy `solve_ivp`: 869 ms
- torchscript torchdiffeq (dopri5): 8.60 seconds
Gradient
- DifferentialEquations.jl: 4.281 ms
- torchscript torchdiffeq: 51.9 seconds
Relative Timings (Lower is better)
Forward Pass
- DifferentialEquations.jl: 1x
- SciPy+Numba: 18x slower
- SciPy: 29x slower
- torchscript torchdiffeq: 4,900x slower
Gradient
- DifferentialEquations.jl: 1x
- torchscript torchdiffeq: 12,000x slower
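The relative timings in these sections are the absolute timings divided by the fastest one. A minimal sketch of that arithmetic for the Lorenz forward pass, using the timings listed above converted to milliseconds (the dictionary and the function name `relative_timings` are hypothetical, introduced only for illustration):

```python
# Lorenz forward-pass timings from the list above, in milliseconds.
lorenz_forward_ms = {
    "DifferentialEquations.jl": 1.742,
    "SciPy+Numba": 30.8,
    "SciPy": 50.2,
    "torchscript torchdiffeq (dopri5)": 8600.0,  # 8.60 seconds
}

def relative_timings(timings):
    """Return each timing as a multiple of the fastest entry."""
    fastest = min(timings.values())
    return {name: t / fastest for name, t in timings.items()}

rel = relative_timings(lorenz_forward_ms)
for name, factor in rel.items():
    print(f"{name}: {factor:.1f}x")
```

Rounded to two significant figures, these ratios match the 18x, 29x, and 4,900x figures reported for the Lorenz forward pass.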
Pleiades Equation (28 ODEs)
Absolute Timings
Forward Pass
- DifferentialEquations.jl: 2.118 ms
- DifferentialEquations.jl DP5: 3.407 ms
- SciPy+Numba: 2.6 ms
- SciPy: 13.2 ms
- torchscript torchdiffeq (dopri5): 405 ms
Gradient
- DifferentialEquations.jl: 5.501 ms
- torchscript torchdiffeq: 6.33 seconds
Relative Timings (Lower is better)
Forward Pass
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 1.6x slower
- SciPy+Numba: 1.2x slower
- SciPy: 6.2x slower
- torchscript torchdiffeq (dopri5): 190x slower
Gradient
- DifferentialEquations.jl: 1x
- torchscript torchdiffeq: 1,200x slower
Non-stiff Reaction Diffusion Equation (N=16) (768 ODEs)
Absolute Timings
- DifferentialEquations.jl: 3.300 ms
- DifferentialEquations.jl DP5: 9.135 ms
- SciPy: 2.2 seconds
- SciPy+Numba: Failed to compile (numpy.ndarray)
- torchscript torchdiffeq (dopri5): 2.78 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.8x slower
- SciPy: 670x slower
- torchscript torchdiffeq (dopri5): 840x slower
Non-stiff Reaction Diffusion Equation (N=32) (3072 ODEs)
Absolute Timings
- DifferentialEquations.jl: 14.397 ms
- DifferentialEquations.jl DP5: 38.608 ms
- SciPy: 6.71 seconds
- torchscript torchdiffeq (dopri5): 3.12 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.7x slower
- SciPy: 460x slower
- torchscript torchdiffeq (dopri5): 220x slower
Non-stiff Reaction Diffusion Equation (N=64) (12,288 ODEs)
Absolute Timings
- DifferentialEquations.jl: 64.192 ms
- DifferentialEquations.jl DP5: 192.216 ms
- SciPy: 174 seconds
- torchscript torchdiffeq (dopri5): 5.24 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.0x slower
- SciPy: 2,700x slower
- torchscript torchdiffeq (dopri5): 82x slower
Non-stiff Reaction Diffusion Equation (N=128) (49,152 ODEs)
Absolute Timings
- DifferentialEquations.jl: 299.512 ms
- DifferentialEquations.jl DP5: 907.863 ms
- torchscript torchdiffeq (dopri5): 9.29 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.0x slower
- torchscript torchdiffeq (dopri5): 31x slower
Non-stiff Reaction Diffusion Equation (N=256) (196,608 ODEs)
Absolute Timings
- DifferentialEquations.jl: 1.586 seconds
- DifferentialEquations.jl DP5: 6.195 seconds
- torchscript torchdiffeq (dopri5): 37.5 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.9x slower
- torchscript torchdiffeq (dopri5): 24x slower
Non-stiff Reaction Diffusion Equation (N=512) (786,432 ODEs)
Absolute Timings
- DifferentialEquations.jl: 10.3 seconds
- DifferentialEquations.jl DP5: 29.3 seconds
- torchscript torchdiffeq (dopri5): 172.59 seconds
Relative Timings (Lower is better)
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.8x slower
- torchscript torchdiffeq (dopri5): 17x slower
Notes
The torchscript versions are kept as separate scripts to allow the JIT compilation to occur, and are called once before timing to exclude JIT time, as suggested by the PyTorch documentation. Python timings were divided by the number of runs executed in timeit to give a per-run figure. Note that the jump in SciPy timing on the reaction-diffusion problem is due to LSODA switching to its BDF (stiff) method and using the Jacobian; with this diffusion coefficient that is unnecessary and leads to a large slowdown.
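The warm-up and scaling described above can be sketched with the standard-library `timeit` module. This is only an illustration of the procedure, not the benchmark scripts themselves; `workload` is a hypothetical stand-in for a solver call, and `n_runs` is an arbitrary choice.

```python
import timeit

def workload():
    # Hypothetical stand-in for the solve being benchmarked.
    return sum(i * i for i in range(1000))

# Warm-up call before timing, so one-time costs (such as JIT
# compilation in the torchscript case) are excluded.
workload()

n_runs = 100  # arbitrary number of repetitions for this sketch
total_seconds = timeit.timeit(workload, number=n_runs)

# timeit returns the total for all runs, so divide to get per-run time.
per_run_seconds = total_seconds / n_runs
print(f"{per_run_seconds * 1e3:.4f} ms per run")
```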
I'm relatively new to both Python and Julia, so bear with me.
First of all, thanks for your amazing work on the DiffEqFlux library; it looks super promising for my purpose.
However, when I run lorenz.py as it is I get the following trace:
Without the @torch.jit.script it runs fine however. Is it the code or is it my environment?