Christopher Rackauckas (ChrisRackauckas)

@ChrisRackauckas
ChrisRackauckas / diffeqflux_differentialequations_vs_torchdiffeq_results.md
Last active December 20, 2023 13:10
torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

torchdiffeq vs DifferentialEquations.jl (and DiffEqFlux.jl): Benchmarks of Neural-ODE-Compatible Solvers

Only non-stiff ODE solvers are tested, since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and in model-informed drug development (MIDD) studies such as quantitative systems pharmacology, so that the benchmarks capture performance in realistic scenarios.

Summary

Below are the timings relative to the fastest method (lower is better). For approximately 1 million ODEs or fewer, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl.

ChrisRackauckas / neural_ode_animation.jl
Created November 25, 2019 00:29
Animation of neural ordinary differential equations with DiffEqFlux.jl
using DiffEqFlux, OrdinaryDiffEq, Flux, Plots
# Generate data from a real ODE
u0 = Float32[2.; 0.]; datasize = 30
tspan = (0.0f0,1.5f0)
# Ground-truth dynamics, written in place: du = true_A' * (u.^3)
function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)' * true_A)'
end
# Time points at which the training data are saved
t = range(tspan[1],tspan[2],length=datasize)
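The preview stops before the training data are generated. As a package-free way to see what `trueODEfunc` produces at the save points `t`, the sketch below integrates it with a hand-written fixed-step classical RK4 scheme. This is an assumption-laden stand-in, not the gist's code: RK4 with `nsub` substeps (a hypothetical parameter) replaces the adaptive solvers from OrdinaryDiffEq, and `rk4_step` and `generate_data` are illustrative helper names.

```julia
# Ground-truth dynamics restated from the gist: du = true_A' * (u.^3)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)' * true_A)'
end

# One classical RK4 step for an in-place right-hand side f!(du, u, p, t).
function rk4_step(f!, u, p, t, h)
    k1 = similar(u); k2 = similar(u); k3 = similar(u); k4 = similar(u)
    f!(k1, u, p, t)
    f!(k2, u .+ (h / 2) .* k1, p, t + h / 2)
    f!(k3, u .+ (h / 2) .* k2, p, t + h / 2)
    f!(k4, u .+ h .* k3, p, t + h)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Hypothetical helper: start at t[1] and store the state at every save point,
# taking `nsub` fixed RK4 substeps between consecutive save points.
function generate_data(f!, u0, t; nsub = 20)
    data = zeros(eltype(u0), length(u0), length(t))
    u = copy(u0)
    data[:, 1] = u
    for j in 2:length(t)
        h = (t[j] - t[j-1]) / nsub
        τ = t[j-1]
        for _ in 1:nsub
            u = rk4_step(f!, u, nothing, τ, h)
            τ += h
        end
        data[:, j] = u
    end
    return data
end

u0 = Float32[2.0; 0.0]; datasize = 30
tspan = (0.0f0, 1.5f0)
t = range(tspan[1], tspan[2], length = datasize)
ode_data = generate_data(trueODEfunc, u0, t)   # 2 × 30 matrix, one column per save point
```

A fixed step keeps the sketch short and dependency-free; for real training data an adaptive, error-controlled solver is the better choice.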
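# Dirichlet boundary condition: fixed values l and r at the left and right boundaries
struct DirichletBC{T}
    l::T
    r::T
end
# Interior vector u padded with the boundary values l (left) and r (right)
struct BoundaryPaddedArray{T,T2 <: AbstractVector{T}}
    l::T
    r::T
    u::T2
end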
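To make the roles of the two types concrete, here is a hypothetical sketch of how a `DirichletBC` could pad an interior vector into a `BoundaryPaddedArray` that indexes like an ordinary vector. The `<: AbstractVector{T}` supertype, the `size`/`getindex` methods, and the `bc(u)` call syntax are assumptions for illustration, not definitions taken from the gist or from any package.

```julia
# Hypothetical sketch: DirichletBC is restated from above; the AbstractVector
# supertype, the methods below, and the bc(u) call syntax are assumptions.
struct DirichletBC{T}
    l::T
    r::T
end

struct BoundaryPaddedArray{T,T2<:AbstractVector{T}} <: AbstractVector{T}
    l::T    # left boundary value
    r::T    # right boundary value
    u::T2   # interior values
end

# The padded vector is two entries longer than the interior.
Base.size(Q::BoundaryPaddedArray) = (length(Q.u) + 2,)
Base.IndexStyle(::Type{<:BoundaryPaddedArray}) = IndexLinear()

# Index 1 is the left boundary, the last index is the right boundary,
# and the indices in between map onto the interior vector.
function Base.getindex(Q::BoundaryPaddedArray, i::Int)
    checkbounds(Q, i)
    if i == 1
        return Q.l
    elseif i == length(Q)
        return Q.r
    else
        return Q.u[i-1]
    end
end

# Applying the boundary condition to an interior vector pads it lazily (no copy).
(bc::DirichletBC)(u::AbstractVector) = BoundaryPaddedArray(bc.l, bc.r, u)

bc = DirichletBC(0.0, 1.0)
Q = bc([0.25, 0.5, 0.75])
collect(Q)   # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Keeping the interior vector as a field rather than copying it into a longer array means a boundary condition can be applied without allocating, at the cost of a branch in `getindex`.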
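using Test

# Same structure as @test: evaluate the expression via Test.get_test_result and
# pass the result, together with the original expression, to do_allowedfailure_test.
macro test_allowed_failure(ex, kws...)
    Test.test_expr!("@test_allowed_failure", ex, kws...)  # test_expr! is not exported by Test
    orig_ex = Expr(:inert, ex)
    result = Test.get_test_result(ex, __source__)
    # code to call do_allowedfailure_test with the execution result and original expr
    :(do_allowedfailure_test($result, $orig_ex))
end
function do_allowedfailure_test(result::Test.ExecutionResult, orig_expr)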
ChrisRackauckas / diffeqflux_blog_examples.jl
Created January 19, 2019 04:25
Example code from the DiffEqFlux.jl blog post
using DifferentialEquations, Flux, DiffEqFlux, Plots
####################################################
## Solve an ODE
####################################################
function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end
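As a package-free check on the model (not part of the blog post), the Lotka-Volterra system has a conserved quantity V(x, y) = γx - δ ln x + βy - α ln y, so an accurate integrator should keep V nearly constant along a trajectory. The sketch below assumes the standard predator equation du[2] = -δy + γxy and uses a hand-written fixed-step RK4 scheme in place of DifferentialEquations.jl; the parameter values, step size, and helper names (`invariant`, `rk4_step`, `integrate`) are illustrative assumptions.

```julia
# Restated right-hand side; the second equation is the standard
# Lotka-Volterra predator equation (an assumption, see lead-in).
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
    return nothing
end

# V(x, y) = γx - δ log(x) + βy - α log(y) is conserved along exact solutions.
invariant(u, p) = p[4]*u[1] - p[3]*log(u[1]) + p[2]*u[2] - p[1]*log(u[2])

# One classical RK4 step for an in-place right-hand side f!(du, u, p, t).
function rk4_step(f!, u, p, t, h)
    k1 = similar(u); k2 = similar(u); k3 = similar(u); k4 = similar(u)
    f!(k1, u, p, t)
    f!(k2, u .+ (h/2) .* k1, p, t + h/2)
    f!(k3, u .+ (h/2) .* k2, p, t + h/2)
    f!(k4, u .+ h .* k3, p, t + h)
    return u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Fixed-step integration from t = 0 to t = h * nsteps.
function integrate(f!, u0, p, h, nsteps)
    u = copy(u0)
    for n in 0:nsteps-1
        u = rk4_step(f!, u, p, n*h, h)
    end
    return u
end

p = [1.5, 1.0, 3.0, 1.0]      # α, β, δ, γ (illustrative values)
u0 = [1.0, 1.0]
V0 = invariant(u0, p)         # 2.0 for this starting point
u_end = integrate(lotka_volterra, u0, p, 0.001, 10_000)   # integrate to t = 10
drift = abs(invariant(u_end, p) - V0)   # small when the integrator is accurate
```

Tracking an invariant like V is a cheap, solver-independent accuracy check that works equally well on the output of an adaptive solver.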
using Sundials, DiffEqBase
# Lorenz system with σ = 10, ρ = 28, β = 8/3, in place
function lorenz(du,u,p,t)
    du[1] = 10.0*(u[2]-u[1])
    du[2] = u[1]*(28.0-u[3]) - u[2]
    du[3] = u[1]*u[2] - (8/3)*u[3]
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan)
# CVODE's Adams-Moulton (non-stiff) method at tight tolerances
sol = solve(prob,CVODE_Adams(),reltol=1e-12,abstol=1e-12)
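Before handing `lorenz` to `CVODE_Adams`, the in-place right-hand side can be sanity-checked by calling it directly, with no solver involved. The sketch below evaluates it at `u0` and at one of the two nontrivial equilibria of the Lorenz system, (±√(β(ρ-1)), ±√(β(ρ-1)), ρ-1) with β = 8/3 and ρ = 28, where the derivative should vanish up to rounding. Passing `nothing` for the parameter argument is an assumption that works only because `lorenz` ignores `p`.

```julia
# Same in-place Lorenz right-hand side as above (σ = 10, ρ = 28, β = 8/3).
function lorenz(du, u, p, t)
    du[1] = 10.0*(u[2]-u[1])
    du[2] = u[1]*(28.0-u[3]) - u[2]
    du[3] = u[1]*u[2] - (8/3)*u[3]
    return nothing
end

# Derivative at the initial condition used above
u0 = [1.0; 0.0; 0.0]
du0 = similar(u0)
lorenz(du0, u0, nothing, 0.0)   # du0 == [-10.0, 28.0, 0.0]

# A nontrivial equilibrium: x = y = sqrt(β(ρ - 1)), z = ρ - 1
xeq = sqrt((8/3) * 27)
ueq = [xeq, xeq, 27.0]
dueq = similar(ueq)
lorenz(dueq, ueq, nothing, 0.0)
residual = maximum(abs, dueq)    # zero up to floating-point rounding
```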
[[ASTInterpreter2]]
deps = ["DebuggerFramework"]
git-tree-sha1 = "8df3d36e0286777d226f4fd4956a432b73425186"
uuid = "e6d88f4b-b52a-544c-a8d3-7a4f12cb39c3"
version = "0.1.1"
[[AbstractFFTs]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "8d59c3b1463b5e0ad05a3698167f85fac90e184d"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
[deps]
ApproxFun = "28f2ccd6-bb30-5033-b560-165f7b14dc2f"
ArrayInterface = "1f957be0-811d-5197-a864-00d3e8faeebd"
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
2005860 "Tuple{getfield(Base, Symbol("##644#645")){String, Base.UUID, String}, Base.IOStream}"
160129 "Tuple{getfield(Base, Symbol("##open#294")), Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Function, getfield(Base, Symbol("##644#645")){String, Base.UUID, String}, String}"
247366 "Tuple{typeof(Base.isassigned), Core.SimpleVector, Int64}"
66746 "Tuple{typeof(Compat.Sys.__init__)}"
1306210 "Tuple{typeof(Base.create_expr_cache), String, String, Array{Base.Pair{Base.PkgId, UInt64}, 1}, Base.UUID}"
166863 "Tuple{Type{NamedTuple{(:stderr,), T} where T<:Tuple}, Tuple{Base.TTY}}"
155738 "Tuple{getfield(Base, Symbol("#kw##pipeline")), NamedTuple{(:stderr,), Tuple{Base.TTY}}, typeof(Base.pipeline), Base.Cmd}"
130856 "Tuple{getfield(Base, Symbol("##pipeline#489")), Nothing, Nothing, Base.TTY, Bool, Function, Base.Cmd}"
1075237 "Tuple{typeof(Base.open), Base.CmdRedirect, String, Base.TTY}"
278396 "Tuple{getfield(Base, Symbol("##open#503")), Bool, Bool, Function, Base.CmdRedirect, Base.TTY}"
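These rows are not valid strict CSV: the quoted signatures themselves contain double quotes (for example `Symbol("##644#645")`). A minimal sketch for reading them without a CSV library is to split each line at its first space and strip only the outermost quotes. The function name `parse_row` and the field names are hypothetical, and since the source does not say what the leading integer measures, it is kept as an opaque `value`.

```julia
# Hypothetical parser for rows of the form: <integer> "<signature>"
# Only the outermost quotes are removed, so quotes inside the signature survive.
function parse_row(line::AbstractString)
    num, sig = split(strip(line), ' '; limit = 2)
    sig = strip(sig)
    if length(sig) >= 2 && startswith(sig, "\"") && endswith(sig, "\"")
        sig = sig[nextind(sig, firstindex(sig)):prevind(sig, lastindex(sig))]
    end
    return (value = parse(Int, num), signature = String(sig))
end

rows = [
    "247366 \"Tuple{typeof(Base.isassigned), Core.SimpleVector, Int64}\"",
    "66746 \"Tuple{typeof(Compat.Sys.__init__)}\"",
    "2005860 \"Tuple{getfield(Base, Symbol(\"##644#645\")){String, Base.UUID, String}, Base.IOStream}\"",
]

# Largest values first
parsed = sort(parse_row.(rows); by = r -> r.value, rev = true)
```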