The JAX developers optimized a differential equation benchmark in this issue, which used DiffEqFlux.jl as a performance baseline. The Julia code from there was updated to include some standard performance tricks and serves as the benchmark code here. Thus both codes have been optimized by their respective library developers.
Only non-stiff ODE solvers are tested, since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and model-informed drug development (MIDD) studies (quantitative systems pharmacology) in order to capture performance on realistic scenarios.
Below are the timings relative to the fastest method (lower is better). For approximately 1 million ODEs and fewer, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl. All problems are solved at reltol=1e-3 and abstol=1e-6, using the fastest ODE solver of the respective package for the given problem.
- SciPy LSODA through odeint takes ~489μs
- SciPy LSODA through odeint with Numba takes ~257μs
- NumbaLSODA takes ~25μs
- DifferentialEquations.jl Rosenbrock23 takes ~9.2μs
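As a point of reference for how the SciPy entries above are typically measured, here is a minimal sketch of timing `scipy.integrate.odeint` (which wraps LSODA) at the benchmark tolerances. The 2D linear ODE below is purely illustrative, not the actual benchmark problem:

```python
# Illustrative sketch: timing SciPy's LSODA via odeint at the benchmark
# tolerances (reltol=1e-3, abstol=1e-6). The ODE here is a stand-in, not
# the problem used in the benchmarks above.
import time
import numpy as np
from scipy.integrate import odeint

def rhs(u, t):
    # du1/dt = -0.5*u1, du2/dt = 0.5*u1 - 0.1*u2
    return [-0.5 * u[0], 0.5 * u[0] - 0.1 * u[1]]

t = np.linspace(0.0, 10.0, 101)
u0 = [1.0, 0.0]

start = time.perf_counter()
sol = odeint(rhs, u0, t, rtol=1e-3, atol=1e-6)
elapsed = time.perf_counter() - start

print(sol.shape)  # (101, 2)
```

In practice the benchmark timings are best-of-many measurements (e.g. via `timeit` in Python or BenchmarkTools.jl in Julia) so that one-time compilation and warm-up costs are excluded.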
using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot

const printbranch = true

Cassette.@context HasBranchingCtx

function Cassette.overdub(ctx::HasBranchingCtx, f, args...)
    if Cassette.canrecurse(ctx, f, args...)
        return Cassette.recurse(ctx, f, args...)
    else
        # Fall back to calling f directly when it cannot be recursed into
        return Cassette.fallback(ctx, f, args...)
    end
end
This example is a 4-dimensional geometric Brownian motion. The code
for the torchsde version is pulled directly from the
torchsde README
so that it would be a fair comparison against the authors' own code.
The only change to that example is the addition of a fixed dt
choice so that
the simulation method and time step match between the two programs.
The SDE is solved 100 times. The summary of the results is as follows:
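To make the setup concrete, here is a minimal sketch (not the torchsde benchmark itself) of simulating a 4-dimensional geometric Brownian motion dX = μX dt + σX dW with a fixed time step via Euler–Maruyama, repeated 100 times as in the benchmark; the drift, diffusion, and dt values are illustrative assumptions:

```python
# Illustrative sketch: 4D geometric Brownian motion with a fixed dt,
# integrated by Euler-Maruyama and repeated 100 times. The parameter
# values (mu, sigma, dt, n_steps) are assumptions, not the benchmark's.
import numpy as np

def gbm_euler_maruyama(x0, mu, sigma, dt, n_steps, rng):
    # One Euler-Maruyama path; x0, mu, sigma are length-4 vectors.
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=x.shape)
        x = x + mu * x * dt + sigma * x * dw
    return x

rng = np.random.default_rng(0)
x0 = np.ones(4)
mu = np.full(4, 0.05)
sigma = np.full(4, 0.2)

# Solve the SDE 100 times, matching the repetition count in the benchmark.
finals = np.array([gbm_euler_maruyama(x0, mu, sigma, 1e-2, 100, rng)
                   for _ in range(100)])
print(finals.shape)  # (100, 4)
```

Fixing dt (rather than letting each library pick its own step size) is what makes the cross-library comparison fair: both programs then perform the same number of diffusion steps with the same method.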
The spiral neural ODE was used as the training benchmark for both torchdiffeq (Python) and DiffEqFlux (Julia), which utilized the same architecture and 500 steps of ADAM. Both achieved similar objective values at the end. Results:
- DiffEqFlux defaults: 7.4 seconds
- DiffEqFlux optimized: 2.7 seconds
- torchdiffeq: 289.0 seconds
using OrdinaryDiffEq, RecursiveArrayTools, LinearAlgebra, Test, SparseArrays, SparseDiffTools, Sundials

# Define the constants for the PDE
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
ERROR: MethodError: Cannot `convert` an object of type Float32 to an object of type Vector{Float32}
Closest candidates are:
  convert(::Type{Array{T, N}}, ::SizedArray{S, T, N, N, Array{T, N}}) where {S, T, N} at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\SizedArray.jl:121
  convert(::Type{Array{T, N}}, ::SizedArray{S, T, N, M, TData} where {M, TData<:AbstractArray{T, M}}) where {T, S, N} at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\SizedArray.jl:115
  convert(::Type{<:Array}, ::LabelledArrays.LArray) at C:\Users\accou\.julia\packages\LabelledArrays\lfn1b\src\larray.jl:133
  ...
Stacktrace:
  [1] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Vector{Float32}, Float32}, f::Symbol, v::Float32)
    @ Base .\Base.jl:43
  [2] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Vector{Float32}, Float32})
function (ˍ₋out, ˍ₋arg1, ˍ₋arg2, t)
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:349 =#
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:350 =#
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:351 =#
    begin
        begin
            #= C:\Users\accou\.julia\packages\Symbolics\vQXbU\src\build_function.jl:452 =#
            #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:398 =# @inbounds begin
                #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:394 =#
                ˍ₋out[1] = (/)((*)((*)((*)((*)((*)((*)(2, ˍ₋arg2[4]), ˍ₋arg2[2]), ˍ₋arg1[1]), (+)(0.5, (*)(0.5, (tanh)((/)((+)(ˍ₋arg1[14], (/)((*)((*)(ˍ₋arg2[4], ˍ₋arg2[1]), (sqrt)((+)((+)((^)((+)((*)((*)(ˍ₋arg1[9], (sin)(ˍ₋arg1[6])), (sqrt)(ˍ₋arg1[1])), (*)((*)((*)(-1, ˍ₋arg1[10]), (cos)(ˍ₋arg1[6])), (sqrt)(ˍ₋arg1[1]))), 2), (^)((+)((+)((/)((*)((*)(ˍ₋arg1[9], (+)(ˍ₋arg1[2], (*)((+)((+)(2, (*)(ˍ₋arg1[2], (cos)(ˍ₋arg1[6]))), (*)(ˍ₋arg1[3], (
<!DOCTYPE html> | |
<html lang="en">
<head>
<meta charset="UTF-8"/> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes"> | |
<title>Forward and Reverse Automatic Differentiation In A Nutshell</title> | |
<script type="text/x-mathjax-config"> | |
MathJax.Hub.Config({ |