torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

torchdiffeq vs DifferentialEquations.jl (and DiffEqFlux.jl): Benchmarks of Neural ODE-Compatible Solvers

Only non-stiff ODE solvers are tested, since torchdiffeq does not provide methods for stiff ODEs. The ODEs were chosen to be representative of models seen in physics and in model-informed drug development (MIDD) studies, such as quantitative systems pharmacology (QSP), so that the benchmarks capture performance on realistic scenarios.
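Stiffness is why this scope matters. Non-stiff solvers such as explicit Runge-Kutta methods have bounded stability regions. On a problem with a rapidly decaying component, their step size is therefore limited by stability rather than by accuracy. The following minimal Julia sketch shows the effect using forward Euler on the linear test equation y' = -λy. Forward Euler is the simplest explicit method and is not one of the benchmarked solvers. The function name forward_euler_decay and the values of λ and h are illustrative assumptions.

```julia
# Forward (explicit) Euler on the linear test problem y' = -λ*y, y(0) = 1.
# Each step computes y_{n+1} = y_n + h*(-λ*y_n) = (1 - h*λ)*y_n, so the
# iterates stay bounded only when |1 - h*λ| <= 1, i.e. when h <= 2/λ.
function forward_euler_decay(λ::Float64, h::Float64, nsteps::Int)
    y = 1.0
    for _ in 1:nsteps
        y += h * (-λ * y)
    end
    return y
end

λ = 1000.0  # fast decay rate: the exact solution exp(-λ*t) vanishes very quickly

# h = 5e-4 < 2/λ: each step multiplies y by about 0.5, so y decays as it should
stable = forward_euler_decay(λ, 5.0e-4, 100)

# h = 3e-3 > 2/λ: each step multiplies y by about -2, so |y| grows like 2^n
unstable = forward_euler_decay(λ, 3.0e-3, 100)

println("stable: ", stable, "  unstable: ", unstable)
```

With λ = 1000, forward Euler stays bounded only for h ≤ 2/λ = 0.002, even though the true solution is smooth after a very short transient. Stiff solvers, which are typically implicit, avoid this restriction.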


Below are the timings relative to the fastest method (lower is better). For systems of up to approximately 1 million ODEs, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl.
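Assuming the usual normalization, a relative timing is each method's runtime divided by the runtime of the fastest method on the same problem. The fastest method therefore scores 1, and "more than an order of magnitude slower" means a relative timing above 10. The following small Julia sketch shows that normalization. The method names and runtimes are hypothetical placeholders, not results from these benchmarks.

```julia
# Normalize each method's runtime by the fastest runtime: the fastest method
# gets 1.0, and a value of 10.0 means "10x slower than the fastest".
function relative_timings(times::Dict{String,Float64})
    fastest = minimum(values(times))
    return Dict(name => t / fastest for (name, t) in times)
end

# Hypothetical runtimes in seconds, for illustration only (not benchmark results)
times = Dict("method_a" => 0.002, "method_b" => 0.05)
rel = relative_timings(times)

# Print methods from fastest to slowest, flagging those more than 10x slower
for (name, r) in sort(collect(rel); by = last)
    println(rpad(name, 10), round(r; digits = 1), r > 10 ? "  (>10x slower)" : "")
end
```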

ChrisRackauckas / neural_ode_animation.jl
Created November 25, 2019 00:29
Animation of neural ordinary differential equations with DiffEqFlux.jl
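using DiffEqFlux, OrdinaryDiffEq, Flux, Plots
# Generate data from a real ODE
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
# Cubic spiral: du/dt = ((u.^3)' * true_A)'
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
# Sample the true solution at `datasize` evenly spaced time points
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))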
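The right-hand side in trueODEfunc is ((u.^3)'true_A)'. It turns the cubed state into a row vector, multiplies that row by true_A, and takes the adjoint to get a column vector again. For real matrices, this equals the transpose of true_A multiplied by u.^3. The following short Julia check confirms the equivalence. spiral_rhs is a hypothetical helper, not part of the original gist.

```julia
using LinearAlgebra

# Hypothetical out-of-place version of the cubic-spiral right-hand side,
# written as a matrix-vector product: du/dt = transpose(A) * u.^3
spiral_rhs(u, A) = transpose(A) * u.^3

true_A = [-0.1 2.0; -2.0 -0.1]
u = Float32[2.0, 0.0]

# Row-vector form from trueODEfunc: (u.^3)' is a 1x2 row, multiplied by A,
# then adjoint-ed back to a column (same as transpose for real entries)
row_form = ((u.^3)'true_A)'
matvec_form = spiral_rhs(u, true_A)

# For u = [2, 0]: u.^3 = [8, 0], so both forms equal 8 * (first row of true_A)
println(row_form, "  ", matvec_form)
```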
ChrisRackauckas / diffeqflux_blog_examples.jl
Created January 19, 2019 04:25
Example code from the DiffEqFlux.jl blog post
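using DifferentialEquations, Flux, DiffEqFlux, Plots
## Solve an ODE
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y   # prey
  du[2] = dy = -δ*y + γ*x*y  # predator
end
u0 = [1.0, 1.0]; tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p)
sol = solve(prob)  # DifferentialEquations.jl picks a default algorithm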
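The same Lotka-Volterra system can also be integrated with a hand-written explicit solver, which shows what a non-stiff method does step by step. This sketch uses only Base Julia. It uses a fixed-step classic fourth-order Runge-Kutta method, a simpler technique swapped in for illustration, not the adaptive algorithm that solve chooses. The helpers lv_rhs, rk4_step, integrate_rk4 and lv_invariant are hypothetical, and the parameter values are illustrative. Accuracy is checked through V = γx - δ log x + βy - α log y. For dx/dt = αx - βxy and dy/dt = -δy + γxy, V is constant along exact solutions.

```julia
# Hypothetical helpers for a dependency-free, fixed-step RK4 solve of the
# Lotka-Volterra system, with parameters ordered (α, β, δ, γ) as in lotka_volterra.

# Out-of-place right-hand side
function lv_rhs(u, p)
    x, y = u
    α, β, δ, γ = p
    return [α*x - β*x*y, -δ*y + γ*x*y]
end

# One classic fourth-order Runge-Kutta step of size h
function rk4_step(f, u, p, h)
    k1 = f(u, p)
    k2 = f(u .+ (h / 2) .* k1, p)
    k3 = f(u .+ (h / 2) .* k2, p)
    k4 = f(u .+ h .* k3, p)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Take nsteps fixed steps from u0 and return the final state
function integrate_rk4(f, u0, p, h, nsteps)
    u = u0
    for _ in 1:nsteps
        u = rk4_step(f, u, p, h)
    end
    return u
end

# V = γx - δ log(x) + βy - α log(y) is constant along exact solutions
function lv_invariant(u, p)
    x, y = u
    α, β, δ, γ = p
    return γ*x - δ*log(x) + β*y - α*log(y)
end

# Illustrative parameters and initial condition
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
h = 1.0e-3
u_end = integrate_rk4(lv_rhs, u0, p, h, 10_000)  # integrate from t = 0 to t = 10

# Relative change in V measures the accumulated integration error
V0 = lv_invariant(u0, p)
drift = abs(lv_invariant(u_end, p) - V0) / abs(V0)
println("relative invariant drift: ", drift)
```

A fixed step keeps the sketch short. Adaptive solvers instead choose h from local error estimates, which is how the benchmarked methods are meant to be used.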
using Sundials, DiffEqBase
# Lorenz system with σ = 10, ρ = 28, β = 8/3
function lorenz(du,u,p,t)
  du[1] = 10.0*(u[2]-u[1])
  du[2] = u[1]*(28.0-u[3]) - u[2]
  du[3] = u[1]*u[2] - (8/3)*u[3]
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan)
# CVODE_Adams: Sundials' Adams-Moulton method for non-stiff problems
sol = solve(prob,CVODE_Adams(),reltol=1e-12,abstol=1e-12)
