torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

Torchdiffeq vs DifferentialEquations.jl (/ DiffEqFlux.jl) Neural ODE Compatible Solver Benchmarks

Only non-stiff ODE solvers are tested since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and model-informed drug development (MIDD) studies (quantitative systems pharmacology) in order to capture performance in realistic scenarios.

Summary

Below are the timings relative to the fastest method (lower is better). On every tested problem, up to approximately 1 million ODEs, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl, and on the smaller problems it was several orders of magnitude slower. Note, however, that the relative performance of torchdiffeq improves as the number of ODEs increases.

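| Number of ODEs                  | 3      | 28     | 768  | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
|---------------------------------|--------|--------|------|-------|--------|--------|---------|---------|
| DifferentialEquations.jl        | 1.0x   | 1.0x   | 1.0x | 1.0x  | 1.0x   | 1.0x   | 1.0x    | 1.0x    |
| DifferentialEquations.jl dopri5 | 1.0x   | 1.9x   | 3.0x | 2.9x  | 2.9x   | 2.9x   | 3.4x    | 2.6x    |
| torchdiffeq dopri5              | 5,850x | 1,700x | 420x | 280x  | 120x   | 31x    | 41x     | 38x     |
| torchdiffeq adams               | 7,600x | 1,100x | 710x | 490x  | 170x   | 44x    | 47x     | 43x     |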
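
As a minimal sketch of how the relative figures are obtained, the snippet below divides each absolute timing by the fastest one, using the Lorenz timings listed under Benchmark Details. The helper name `relative_timings` is hypothetical and not part of the benchmark scripts.

```python
# Absolute timings in seconds, copied from the Lorenz entries under Benchmark Details.
lorenz_timings = {
    "DifferentialEquations.jl": 1.675e-3,
    "SciPy+Numba": 43.6e-3,
    "SciPy": 78.4e-3,
    "torchscript torchdiffeq (dopri5)": 9.8,
    "torchscript torchdiffeq (adams)": 12.8,
}


def relative_timings(timings):
    """Return each timing divided by the fastest one (hypothetical helper)."""
    fastest = min(timings.values())
    return {name: seconds / fastest for name, seconds in timings.items()}


for name, ratio in relative_timings(lorenz_timings).items():
    print(f"{name}: {ratio:.1f}x")
```

The summary rounds these ratios, so small differences in the last digit are expected.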

Benchmark Details

Lorenz Equation (3 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 1.675 ms
  • SciPy+Numba: 43.6 ms
  • SciPy: 78.4 ms
  • torchscript torchdiffeq (dopri5): 9.8 seconds
  • torchscript torchdiffeq (adams): 12.8 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • SciPy+Numba: 26x slower
  • SciPy: 47x slower
  • torchscript torchdiffeq (dopri5): 5,850x slower
  • torchscript torchdiffeq (adams): 7,600x slower

Pleiades Equation (28 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 4.052 ms
  • DifferentialEquations.jl DP5: 7.697 ms
  • SciPy+Numba: 5.20 ms
  • SciPy: 206 ms
  • torchscript torchdiffeq (dopri5): 6.7 seconds
  • torchscript torchdiffeq (adams): 4.5 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 1.9x slower
  • SciPy+Numba: 1.3x slower
  • SciPy: 51x slower
  • torchscript torchdiffeq (dopri5): 1700x slower
  • torchscript torchdiffeq (adams): 1100x slower

Non-stiff Reaction Diffusion Equation (N=16) (768 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 4.815 ms
  • DifferentialEquations.jl DP5: 14.386 ms
  • SciPy: 1.3 seconds
  • SciPy+Numba: Failed to compile (numpy.ndarray)
  • torchscript torchdiffeq (dopri5): 2.0 seconds
  • torchscript torchdiffeq (adams): 3.4 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.0x slower
  • SciPy: 270x slower
  • torchscript torchdiffeq (dopri5): 420x slower
  • torchscript torchdiffeq (adams): 710x slower

Non-stiff Reaction Diffusion Equation (N=32) (3072 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 15.890 ms
  • DifferentialEquations.jl DP5: 45.871 ms
  • SciPy: 5.8 seconds
  • torchscript torchdiffeq (dopri5): 4.5 seconds
  • torchscript torchdiffeq (adams): 7.8 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.9x slower
  • SciPy: 370x slower
  • torchscript torchdiffeq (dopri5): 280x slower
  • torchscript torchdiffeq (adams): 490x slower

Non-stiff Reaction Diffusion Equation (N=64) (12,288 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 65.113 ms
  • DifferentialEquations.jl DP5: 187.491 ms
  • SciPy: 176 seconds
  • torchscript torchdiffeq (dopri5): 7.6 seconds
  • torchscript torchdiffeq (adams): 11.3 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.9x slower
  • SciPy: 2,700x slower
  • torchscript torchdiffeq (dopri5): 120x slower
  • torchscript torchdiffeq (adams): 170x slower

Non-stiff Reaction Diffusion Equation (N=128) (49,152 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 299.065 ms
  • DifferentialEquations.jl DP5: 865.763 ms
  • torchscript torchdiffeq (dopri5): 9.2 seconds
  • torchscript torchdiffeq (adams): 13.1 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.9x slower
  • torchscript torchdiffeq (dopri5): 31x slower
  • torchscript torchdiffeq (adams): 44x slower

Non-stiff Reaction Diffusion Equation (N=256) (196,608 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 1.9 seconds
  • DifferentialEquations.jl DP5: 6.5 seconds
  • torchscript torchdiffeq (dopri5): 78 seconds
  • torchscript torchdiffeq (adams): 90 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.4x slower
  • torchscript torchdiffeq (dopri5): 41x slower
  • torchscript torchdiffeq (adams): 47x slower

Non-stiff Reaction Diffusion Equation (N=512) (786,432 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 12.3 seconds
  • DifferentialEquations.jl DP5: 32.5 seconds
  • torchscript torchdiffeq (dopri5): 462 seconds
  • torchscript torchdiffeq (adams): 523 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.6x slower
  • torchscript torchdiffeq (dopri5): 38x slower
  • torchscript torchdiffeq (adams): 43x slower
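
The system sizes quoted in the reaction-diffusion headings follow from the discretization used in the scripts: three species (A, B, C) on an N×N grid give 3·N² ODEs. A small sketch of that arithmetic, where the function name `num_odes` is an illustrative assumption:

```python
def num_odes(n, species=3):
    """Number of ODEs for `species` fields on an n-by-n grid (hypothetical helper)."""
    return species * n * n


# Grid sizes used in the reaction-diffusion benchmarks.
for n in (16, 32, 64, 128, 256, 512):
    print(f"N={n}: {num_odes(n):,} ODEs")
```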

Notes

The torchscript versions are kept as separate scripts to allow the JIT compilation to occur, and each is called once before timing so that JIT compilation time is excluded, as suggested by the PyTorch documentation. Python timings were divided by the number of runs passed to timeit. Note that the SciPy slowdown on the reaction-diffusion problem is due to LSODA switching to its BDF (stiff) method and using the Jacobian: with this diffusion coefficient this is unnecessary and leads to a large slowdown.
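
The timing procedure described above (one untimed warm-up call, then dividing the timeit total by the number of runs) can be sketched as follows. The helper `mean_runtime` and the `workload` function are hypothetical stand-ins, not code from the benchmark scripts.

```python
import timeit


def mean_runtime(func, number=10, warmup=1):
    """Average seconds per call of `func`, after untimed warm-up calls.

    Hypothetical helper: the warm-up calls absorb one-off costs such as
    JIT compilation so that they do not count toward the reported time.
    """
    for _ in range(warmup):
        func()
    total = timeit.Timer(func).timeit(number=number)
    return total / number


def workload():
    # Stand-in for a solver call such as odeint(...).
    return sum(i * i for i in range(10_000))


print(f"{mean_runtime(workload, number=20):.6f} seconds per call")
```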

using OrdinaryDiffEq, StaticArrays, BenchmarkTools
function lorenz_static(u,p,t)
    @inbounds begin
        dx = p[1]*(u[2]-u[1])
        dy = u[1]*(p[2]-u[3]) - u[2]
        dz = u[1]*u[2] - p[3]*u[3]
    end
    @SVector [dx,dy,dz]
end
u0 = @SVector [1.0,0.0,0.0]
p = @SVector [10.0,28.0,8/3]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz_static,u0,tspan,p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 1.879 ms (56 allocations: 60.19 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 1.675 ms (56 allocations: 59.83 KiB)
import numpy as np
from scipy.integrate import odeint
import timeit
import numba
def f(u, t, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)
sol = odeint(f, u0, t, args=(10.0,28.0,8/3))
def time_func():
    odeint(f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.07844269799999892 seconds
numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.04360099399998944 seconds
import torch
from torchdiffeq import odeint
import timeit
class LorenzODE(torch.nn.Module):
    def __init__(self):
        super(LorenzODE, self).__init__()
    def forward(self, t, u):
        x, y, z = u[0],u[1],u[2]
        du1 = 10.0 * (y - x)
        du2 = x * (28.0 - z) - y
        du3 = x * y - (8.0/3.0) * z
        return torch.stack([du1, du2, du3])
# torch.jit.script cannot compile an nn.Module subclass directly, so script an instance
lorenz = torch.jit.script(LorenzODE())
u0 = torch.tensor([1.0,0.0,0.0])
t = torch.linspace(0, 100, 1001)
odeint(lorenz, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    odeint(lorenz, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=2)/2 # 9.803650500000003 seconds
odeint(lorenz, u0, t, rtol = 1e-8, atol=1e-8, method = "adams")
def time_func():
    odeint(lorenz, u0, t, rtol = 1e-8, atol=1e-8, method = "adams")
time_func()
timeit.Timer(time_func).timeit(number=2)/2 # 12.757906950000688 seconds
using OrdinaryDiffEq, BenchmarkTools
function f(du,u,p,t)
    @inbounds begin
        x = view(u,1:7) # x
        y = view(u,8:14) # y
        v = view(u,15:21) # x′
        w = view(u,22:28) # y′
        du[1:7] .= v
        du[8:14].= w
        for i in 15:28
            du[i] = zero(u[1])
        end
        for i=1:7,j=1:7
            if i != j
                r = ((x[i]-x[j])^2 + (y[i] - y[j])^2)^(3/2)
                du[14+i] += j*(x[j] - x[i])/r
                du[21+i] += j*(y[j] - y[i])/r
            end
        end
    end
end
prob = ODEProblem(f,[3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0],(0.0,3.0))
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 9.142 ms (11814 allocations: 568.22 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 7.697 ms (9963 allocations: 480.41 KiB)
@btime solve(prob,VCABM(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 4.052 ms (4644 allocations: 247.47 KiB)
import numpy as np
from scipy.integrate import odeint
import timeit
import numba
def f(u,t):
    # 0-based, end-exclusive slices covering all 7 bodies (matches 1:7, 8:14, ... in the Julia version)
    x = u[0:7] # x
    y = u[7:14] # y
    v = u[14:21] # x′
    w = u[21:28] # y′
    du = np.zeros(28)
    du[0:7] = v
    du[7:14]= w
    for i in range(7):
        for j in range(7):
            if i != j:
                r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                # body j has mass j+1 because j is 0-based here
                du[14+i] += (j+1)*(x[j] - x[i])/r
                du[21+i] += (j+1)*(y[j] - y[i])/r
    return du
u0 = np.array([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = np.linspace(0, 3.0, 31)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.022143833999998606 seconds
numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.0023944810000102734 seconds
import torch
from torchdiffeq import odeint
import timeit
class PleiadesODE(torch.nn.Module):
    def __init__(self):
        super(PleiadesODE, self).__init__()
    def forward(self, t, u):
        # 0-based, end-exclusive slices covering all 7 bodies
        x = u[0:7] # x
        y = u[7:14] # y
        v = u[14:21] # x′
        w = u[21:28] # y′
        du = torch.zeros(28)
        du[0:7] = v
        du[7:14]= w
        for i in range(7):
            for j in range(7):
                if i != j:
                    r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                    # body j has mass j+1 because j is 0-based here
                    du[14+i] += (j+1)*(x[j] - x[i])/r
                    du[21+i] += (j+1)*(y[j] - y[i])/r
        return du
# torch.jit.script cannot compile an nn.Module subclass directly, so script an instance
pleiades = torch.jit.script(PleiadesODE())
u0 = torch.tensor([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = torch.linspace(0, 3.0, 31)
odeint(pleiades, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    odeint(pleiades, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=10)/10 # 0.3721496899999693 seconds
odeint(pleiades, u0, t, rtol = 1e-8, atol=1e-8, method = "adams")
def time_func():
    odeint(pleiades, u0, t, rtol = 1e-8, atol=1e-8, method = "adams")
time_func()
timeit.Timer(time_func).timeit(number=10)/10 # 0.45268095999999786 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 128
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 299.065 ms (16323 allocations: 286.39 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 848.662 ms (42054 allocations: 46.82 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 865.763 ms (44388 allocations: 45.83 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 128
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
tu0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, tu0, t)
sol = odeint(f_script, tu0, t, method="adams")
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 9.22206560000086 seconds
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 13.137924599999678 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 16
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 4.815 ms (16148 allocations: 5.78 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 13.726 ms (42080 allocations: 3.27 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 14.386 ms (44381 allocations: 3.39 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 16
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = My@A
    AMx = A@Mx
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
# odeint is imported directly above, so call it without the scipy.integrate prefix
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 1.2923503000001801 seconds
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, u0, t)
sol = odeint(f_script, u0, t, method="adams")
def time_func():
    odeint(f_script, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 1.9997992999997223 seconds
def time_func():
    odeint(f_script, u0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 3.446258099998886 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 256
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 1.904 s (16137 allocations: 1.10 GiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 6.486 s (42054 allocations: 179.57 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 5.884 s (44388 allocations: 175.21 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 256
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
tu0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, tu0, t)
sol = odeint(f_script, tu0, t, method="adams")
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 78.06950359999973 seconds
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 89.57403810000142 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 32
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 15.890 ms (16737 allocations: 19.43 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 43.184 ms (42162 allocations: 5.34 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 45.871 ms (44460 allocations: 5.41 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 32
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = My@A
    AMx = A@Mx
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
# odeint is imported directly above, so call it without the scipy.integrate prefix
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 5.794173000000228 seconds
import numba
numba_f = numba.jit(f,nopython=True)
numba_f(u0,0.0) # incompatible with numpy.ndarray
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, u0, t)
sol = odeint(f_script, u0, t, method="adams")
def time_func():
    odeint(f_script, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 4.5048727000003055 seconds
def time_func():
    odeint(f_script, u0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 7.796025600000576 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 512
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 12.289 s (16001 allocations: 4.34 GiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 30.663 s (42018 allocations: 710.56 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 32.455 s (44388 allocations: 692.71 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 512
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
tu0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, tu0, t)
sol = odeint(f_script, tu0, t, method="adams")
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 462.25758789999963 seconds
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 522.7533516000003 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 64
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
    A = @view u[:,:,1]
    B = @view u[:,:,2]
    C = @view u[:,:,3]
    dA = @view du[:,:,1]
    dB = @view du[:,:,2]
    dC = @view du[:,:,3]
    mul!(MyA,My,A)
    mul!(AMx,A,Mx)
    @. DA = D*(MyA + AMx)
    @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
    @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
    @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
u0 = zeros(N,N,3)
prob = ODEProblem(f,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 65.113 ms (16527 allocations: 73.49 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 182.185 ms (42126 allocations: 13.63 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.1); # 187.491 ms (44424 allocations: 13.49 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 64
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = My@A
    AMx = A@Mx
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
# odeint is imported directly above, so call it without the scipy.integrate prefix
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 176.5848661 seconds
import numba
numba_f = numba.jit(f,nopython=True)
numba_f(u0,0.0) # incompatible with numpy.ndarray
import torch
from torchdiffeq import odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
@torch.jit.script
def f_script(t, _u):
    u = torch.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    MyA = tMy@A
    AMx = A@tMx
    DA = _DD*(MyA + AMx)
    dA = DA + ta1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return torch.flatten(torch.cat([dA,dB,dC]))
tu0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(f_script, tu0, t)
sol = odeint(f_script, tu0, t, method="adams")
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 7.557895300000382 seconds
def time_func():
    odeint(f_script, tu0, t, rtol = 1e-8, atol=1e-8, method="adams")
time_func()
timeit.Timer(time_func).timeit(number=1) # 11.257385000000795 seconds

Triceert2 commented Aug 10, 2020

I'm relatively new to both python and Julia, so bear with me.
First of all thanks for your amazing work on the diffeqflux library, it looks super promising for my purpose.
However, when I run lorenz.py as it is I get the following trace:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/user/SRC/scratch.py", line 39, in <module>
    class LorenzODE(torch.nn.Module):
  File "/opt/anaconda3/envs/DSP/lib/python3.7/site-packages/torch/jit/__init__.py", line 1525, in script
    " pass an instance instead".format(obj))
RuntimeError: Type '<class '__main__.LorenzODE'>' cannot be compiled since it inherits from nn.Module, pass an instance instead.

Without the @torch.jit.script it runs fine however. Is it the code or is it my environment?

Owner Author

ChrisRackauckas commented Aug 10, 2020

The torch jit has some limitations on how you run the JIT'd code IIRC. I think the easiest one to run was https://gist.github.com/ChrisRackauckas/cc6ac746e2dfd285c28e0584a2bfd320/revisions#diff-0cb4ce6da1b7bafa3e58087c7e9759a6 , so I might just revive that version.
