torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

Torchdiffeq vs DifferentialEquations.jl (/ DiffEqFlux.jl) Neural ODE Compatible Solver Benchmarks

Only non-stiff ODE solvers are tested, since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and in model-informed drug development (MIDD) studies such as quantitative systems pharmacology, so that the benchmarks capture performance on realistic scenarios.

Summary

Below are the timings relative to the fastest method (lower is better). For systems of up to roughly 1 million ODEs, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl on every tested problem, and often slower by a much larger margin. Note, however, that torchdiffeq's relative performance improves as the number of ODEs increases.

Additionally, torchdiffeq's gradient calculations were either slower or diverged. For more on why the torchdiffeq gradient calculation diverges, see this manuscript, along with many others detailing the stability of backsolve adjoint methods. Note that when sensealg=BacksolveAdjoint() is used in DifferentialEquations.jl on these problems, it similarly diverges, indicating that the issue lies with that algorithm (which is why the DiffEq documentation does not recommend this algorithm for such problems!).
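A minimal illustration of the divergence mechanism, under stated assumptions: a backsolve adjoint recovers the forward state by integrating the ODE backward in time, and for a rapidly decaying mode u' = -λu (a stand-in for the fast diffusion modes of the reaction-diffusion problems) that reverse-time solve amplifies any error left in u(T) by a factor of e^(λT). The values of λ, T and the 1e-8 perturbation (roughly the absolute tolerance used in the benchmarks) are illustrative choices, and exact exponential factors stand in for an actual solver.

```python
import math


def backsolve_error(lam: float, T: float, final_state_error: float) -> float:
    """Error in the reconstructed u(0) for the decay mode u' = -lam * u.

    Forward in time the mode decays like exp(-lam * T). Solving the same ODE
    backward from T to 0 (as a backsolve adjoint does to recover the forward
    trajectory) turns that decay into growth, so a perturbation of size
    `final_state_error` in u(T) is amplified by exp(lam * T).
    """
    u0 = 1.0
    uT = u0 * math.exp(-lam * T)                      # exact forward solution
    uT_perturbed = uT + final_state_error             # e.g. error at the level of atol
    u0_reconstructed = uT_perturbed * math.exp(lam * T)  # exact reverse-time solve
    return abs(u0_reconstructed - u0)


for lam in (1.0, 10.0, 50.0):
    print(f"lambda = {lam:5.1f}: |error in reconstructed u(0)| = "
          f"{backsolve_error(lam, 1.0, 1e-8):.3e}")
```

The forward solve is perfectly stable for all three rates; only the backward reconstruction blows up, and it does so faster the stiffer the decaying mode is.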

Relative Performance of Forward Pass (Lower is Better)

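| Number of ODEs | 3 | 28 | 768 | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
|---|---|---|---|---|---|---|---|---|
| DifferentialEquations.jl (fastest method) | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
| DifferentialEquations.jl dopri5 (DP5) | 1.0x | 1.6x | 2.8x | 2.7x | 3.0x | 3.0x | 3.9x | 2.8x |
| torchdiffeq dopri5 | 4,900x | 190x | 840x | 220x | 82x | 31x | 24x | 17x |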

Relative Performance of Gradient Pass (Lower is Better)

| Number of Parameters | 3 | 4 | 256 | 1,024 |
|---|---|---|---|---|
| DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x |
| torchdiffeq dopri5 | 12,000x | 1,200x | ---- | ---- |

"----" means the gradient calculation diverged. To be clear, torchdiffeq did not successfully compute the gradient on any of the PDE experiments: every run stopped with a dt underflow error, which halted the experiment.

Benchmark Details

Lorenz Equation (3 ODEs)

Absolute Timings

Forward Pass

  • DifferentialEquations.jl: 1.742 ms
  • SciPy+Numba: 30.8 ms
  • SciPy: 50.2 ms
  • SciPy solve_ivp: 869 ms
  • torchscript torchdiffeq (dopri5): 8.60 seconds

Gradient

  • DifferentialEquations.jl: 4.281 ms
  • torchscript torchdiffeq: 51.9 seconds

Relative Timings (Lower is better)

Forward Pass

  • DifferentialEquations.jl: 1x
  • SciPy+Numba: 18x slower
  • SciPy: 29x slower
  • SciPy solve_ivp: 500x slower
  • torchscript torchdiffeq: 4,900x slower

Gradient

  • DifferentialEquations.jl: 1x
  • torchscript torchdiffeq: 12,000x slower
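Each relative figure is a method's absolute time divided by the fastest time on the same problem, and the figures appear to be rounded to two significant figures. A small sketch of that arithmetic using the Lorenz forward-pass numbers above (the dictionary and function names are illustrative, not taken from the benchmark scripts):

```python
from typing import Dict

# Absolute Lorenz forward-pass timings listed above, in seconds
forward_seconds = {
    "DifferentialEquations.jl": 1.742e-3,
    "SciPy+Numba": 30.8e-3,
    "SciPy": 50.2e-3,
    "SciPy solve_ivp": 869e-3,
    "torchscript torchdiffeq (dopri5)": 8.60,
}


def relative_timings(timings: Dict[str, float]) -> Dict[str, float]:
    """Divide each timing by the fastest one, so the fastest method reads 1.0x."""
    fastest = min(timings.values())
    return {name: seconds / fastest for name, seconds in timings.items()}


for name, ratio in relative_timings(forward_seconds).items():
    print(f"{name}: {ratio:,.1f}x")
```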

Pleiades Equation (28 ODEs)

Absolute Timings

Forward Pass

  • DifferentialEquations.jl: 2.118 ms
  • DifferentialEquations.jl DP5: 3.407 ms
  • SciPy+Numba: 2.6 ms
  • SciPy: 13.2 ms
  • torchscript torchdiffeq (dopri5): 405 ms

Gradient

  • DifferentialEquations.jl: 5.501 ms
  • torchscript torchdiffeq: 6.33 seconds

Relative Timings (Lower is better)

Forward Pass

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 1.6x slower
  • SciPy+Numba: 1.2x slower
  • SciPy: 6.2x slower
  • torchscript torchdiffeq (dopri5): 190x slower

Gradient

  • DifferentialEquations.jl: 1x
  • torchscript torchdiffeq: 1,200x slower

Non-stiff Reaction Diffusion Equation (N=16) (768 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 3.300 ms
  • DifferentialEquations.jl DP5: 9.135 ms
  • SciPy: 2.2 seconds
  • SciPy+Numba: Failed to compile (numpy.ndarray)
  • torchscript torchdiffeq (dopri5): 2.78 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.8x slower
  • SciPy: 670x slower
  • torchscript torchdiffeq (dopri5): 840x slower

Non-stiff Reaction Diffusion Equation (N=32) (3072 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 14.397 ms
  • DifferentialEquations.jl DP5: 38.608 ms
  • SciPy: 6.71 seconds
  • torchscript torchdiffeq (dopri5): 3.12 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.7x slower
  • SciPy: 470x slower
  • torchscript torchdiffeq (dopri5): 220x slower

Non-stiff Reaction Diffusion Equation (N=64) (12,288 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 64.192 ms
  • DifferentialEquations.jl DP5: 192.216 ms
  • SciPy: 174 seconds
  • torchscript torchdiffeq (dopri5): 5.24 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.0x slower
  • SciPy: 2,700x slower
  • torchscript torchdiffeq (dopri5): 82x slower

Non-stiff Reaction Diffusion Equation (N=128) (49,152 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 299.512 ms
  • DifferentialEquations.jl DP5: 907.863 ms
  • torchscript torchdiffeq (dopri5): 9.29 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.0x slower
  • torchscript torchdiffeq (dopri5): 31x slower

Non-stiff Reaction Diffusion Equation (N=256) (196,608 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 1.586 seconds
  • DifferentialEquations.jl DP5: 6.195 seconds
  • torchscript torchdiffeq (dopri5): 37.5 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.9x slower
  • torchscript torchdiffeq (dopri5): 24x slower

Non-stiff Reaction Diffusion Equation (N=512) (786,432 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 10.3 seconds
  • DifferentialEquations.jl DP5: 29.3 seconds
  • torchscript torchdiffeq (dopri5): 172.59 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.8x slower
  • torchscript torchdiffeq (dopri5): 17x slower

Notes

The torchscript versions are kept as separate scripts so that JIT compilation can take place, and each model is called once before timing so that JIT time is excluded, as suggested by the PyTorch documentation. Python results were divided by the number of runs in timeit. The SciPy slowdown on the reaction-diffusion problem comes from lsoda switching to its BDF method and using the Jacobian: with this diffusion coefficient the switch is unnecessary and leads to a large slowdown.
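The timing protocol described in these notes (one untimed warm-up call, then the timeit total divided by the number of runs) can be captured in a small helper. This is a sketch: `mean_time_per_call` is a hypothetical name that does not appear in the benchmark scripts, and the workload in the usage lines is a placeholder for a `time_func()` such as those below.

```python
import timeit
from typing import Callable


def mean_time_per_call(fn: Callable[[], object], number: int) -> float:
    """Average seconds per call of `fn`, excluding one-off warm-up costs.

    The first, untimed call absorbs JIT compilation (TorchScript, Numba) and
    other one-time setup; timeit then runs `fn` `number` times and the total
    is divided by `number`, as in the benchmark scripts.
    """
    fn()  # untimed warm-up call
    total_seconds = timeit.Timer(fn).timeit(number=number)
    return total_seconds / number


def placeholder_workload() -> None:
    # Stand-in for e.g. an odeint call inside a time_func()
    sum(i * i for i in range(10_000))


print(f"{mean_time_per_call(placeholder_workload, number=100) * 1e3:.3f} ms per call")
```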

# Lorenz equation: DifferentialEquations.jl benchmark
using OrdinaryDiffEq, StaticArrays, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote, Statistics
# Out-of-place right-hand side returning a stack-allocated SVector
function lorenz_static(u,p,t)
    @inbounds begin
        dx = p[1]*(u[2]-u[1])
        dy = u[1]*(p[2]-u[3]) - u[2]
        dz = u[1]*u[2] - p[3]*u[3]
    end
    @SVector [dx,dy,dz]
end
u0 = @SVector [1.0,0.0,0.0]
p = @SVector [10.0,28.0,8/3]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz_static,u0,tspan,p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.018 ms (31 allocations: 59.25 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 1.742 ms (35 allocations: 59.17 KiB)
# In-place version on plain arrays, used for the gradient benchmark
function lorenz!(du,u,p,t)
    du[1] = p[1]*(u[2]-u[1])
    du[2] = u[1]*(p[2]-u[3]) - u[2]
    du[3] = u[1]*u[2] - p[3]*u[3]
end
ap = Array(p)  # parameters as a plain Vector
prob = ODEProblem(lorenz!,Array(u0),tspan,ap)
# Scalar loss: mean of the saved solution, differentiated with forward-mode sensitivities
function fz(p)
    mean(solve(prob,DP5(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,ap) # 4.281 ms (8308 allocations: 1.18 MiB)
# Lorenz equation: SciPy (and SciPy+Numba) benchmark
import numpy as np
from scipy.integrate import odeint
import timeit
import numba

def f(u, t, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)  # save every 0.1, as in the Julia runs
sol = odeint(f, u0, t, args=(10.0,28.0,8/3))  # untimed first call

def time_func():
    odeint(f, u0, t, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.05018504399999983 seconds

# Same right-hand side compiled with Numba
numba_f = numba.jit(f,nopython=True)

def time_func():
    odeint(numba_f, u0, t, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.03088667100000066 seconds
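The summary also lists a SciPy solve_ivp timing, but that script is not part of this gist. A plausible reconstruction, assuming the default RK45 method with the same tolerances and 0.1 output spacing (the exact call behind the 869 ms figure is not shown): solve_ivp expects fun(t, y, *args), the reverse of odeint's f(y, t, *args), and the `args` keyword needs SciPy 1.4 or later.

```python
import numpy as np
from scipy.integrate import solve_ivp


def lorenz(t, u, sigma, rho, beta):
    # solve_ivp calls fun(t, y, *args); odeint calls f(y, t, *args)
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


t_eval = np.linspace(0, 100, 1001)  # save every 0.1 time units
sol = solve_ivp(lorenz, (0.0, 100.0), [1.0, 0.0, 0.0], method="RK45",
                t_eval=t_eval, args=(10.0, 28.0, 8 / 3), rtol=1e-8, atol=1e-8)
print(sol.status, sol.y.shape)
```

Unlike odeint, solve_ivp returns the states as `sol.y` with one column per output time.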
# Lorenz equation: torchdiffeq benchmark (TorchScript)
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn
import torch.optim as optim
import timeit

class LorenzODE(torch.nn.Module):
    def __init__(self):
        super(LorenzODE, self).__init__()
        # Trainable Lorenz parameters (same values as the Julia and SciPy runs)
        self.sigma = nn.Parameter(torch.as_tensor([10.0]))
        self.rho = nn.Parameter(torch.as_tensor([28.0]))
        self.beta = nn.Parameter(torch.as_tensor([8.0/3.0]))

    def forward(self, t, u):
        x, y, z = u[0],u[1],u[2]
        du1 = self.sigma[0] * (y - x)
        du2 = x * (self.rho[0] - z) - y
        du3 = x * y - self.beta[0] * z
        return torch.stack([du1, du2, du3])

u0 = torch.tensor([1.0,0.0,0.0])
t = torch.linspace(0, 100, 1001)
# torch.jit.script compiles a module instance; it cannot be applied to an nn.Module subclass itself
tmp = torch.jit.script(LorenzODE())
odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)  # warm-up call so JIT time is excluded

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=2)/2 # 8.595667100000014 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)

def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()  # gradient via torchdiffeq's adjoint (backsolve) method

time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 51.85800955000013 seconds
# Pleiades equation: DifferentialEquations.jl benchmark
using OrdinaryDiffEq, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote, Statistics
function f(du,u,p,t)
    a = p[1]
    b = p[2]
    c = p[3]
    d = p[4]
    @inbounds begin
        x = view(u,1:7)   # x
        y = view(u,8:14)  # y
        v = view(u,15:21) # x′
        w = view(u,22:28) # y′
        du[1:7] .= a.*v
        du[8:14].= b.*w
        for i in 15:28
            du[i] = zero(u[1])
        end
        # Pairwise interaction between the 7 bodies; body j has mass j
        for i=1:7,j=1:7
            if i != j
                r = ((x[i]-x[j])^2 + (y[i] - y[j])^2)^(3/2)
                du[14+i] += c*j*(x[j] - x[i])/r
                du[21+i] += d*j*(y[j] - y[i])/r
            end
        end
    end
end
p = ones(4)
prob = ODEProblem(f,[3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0],(0.0,3.0),p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 4.060 ms (76 allocations: 20.11 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 3.407 ms (76 allocations: 18.61 KiB)
@btime solve(prob,VCABM(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.118 ms (137 allocations: 36.59 KiB)
function fz(p)
    mean(solve(prob,VCABM(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,p) # 5.501 ms (600 allocations: 265.66 KiB)
# Pleiades equation: SciPy (and SciPy+Numba) benchmark
import numpy as np
from scipy.integrate import odeint
import timeit
import numba

def f(u,t):
    # 0-based slices of the 28-element state: 7 bodies, positions then velocities
    x = u[0:7]   # x
    y = u[7:14]  # y
    v = u[14:21] # x′
    w = u[21:28] # y′
    du = np.zeros(28)
    du[0:7] = v
    du[7:14] = w
    for i in range(7):
        for j in range(7):
            if i != j:
                r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                # body j has mass j+1, matching the 1-based Julia version
                du[14+i] += (j+1)*(x[j] - x[i])/r
                du[21+i] += (j+1)*(y[j] - y[i])/r
    return du

u0 = np.array([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = np.linspace(0, 3.0, 31)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8)  # untimed first call

def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.013212921999997889 seconds

# Same right-hand side compiled with Numba
numba_f = numba.jit(f,nopython=True)

def time_func():
    odeint(numba_f, u0, t, rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.0025926070000059556 seconds
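The O(n²) pairwise loop in the Pleiades right-hand side can also be written with NumPy broadcasting. The sketch below is a variation, not the benchmarked code: it compares a broadcast version against a loop reference that uses 0-based slices for the 7 bodies and gives body j (0-based) mass j+1, as in the 1-based Julia function with a = b = c = d = 1. The function names are illustrative.

```python
import numpy as np
from scipy.integrate import odeint

MASSES = np.arange(1, 8, dtype=float)  # body j (0-based) has mass j + 1


def pleiades_loops(u, t):
    """Reference right-hand side written with explicit loops."""
    x, y, v, w = u[0:7], u[7:14], u[14:21], u[21:28]
    du = np.zeros(28)
    du[0:7] = v
    du[7:14] = w
    for i in range(7):
        for j in range(7):
            if i != j:
                r = ((x[i] - x[j]) ** 2 + (y[i] - y[j]) ** 2) ** 1.5
                du[14 + i] += MASSES[j] * (x[j] - x[i]) / r
                du[21 + i] += MASSES[j] * (y[j] - y[i]) / r
    return du


def pleiades_vectorized(u, t):
    """Same right-hand side, with the pairwise loop replaced by broadcasting."""
    x, y, v, w = u[0:7], u[7:14], u[14:21], u[21:28]
    dx = x[np.newaxis, :] - x[:, np.newaxis]  # dx[i, j] = x[j] - x[i]
    dy = y[np.newaxis, :] - y[:, np.newaxis]  # dy[i, j] = y[j] - y[i]
    r3 = (dx ** 2 + dy ** 2) ** 1.5
    np.fill_diagonal(r3, np.inf)  # i == j terms become 0 / inf = 0
    ax = (MASSES * dx / r3).sum(axis=1)  # MASSES broadcasts over j
    ay = (MASSES * dy / r3).sum(axis=1)
    return np.concatenate([v, w, ax, ay])


u0 = np.array([3.0, 3.0, -1.0, -3.0, 2.0, -2.0, 2.0, 3.0, -3.0, 2.0, 0, 0, -4.0, 4.0,
               0, 0, 0, 0, 0, 1.75, -1.5, 0, 0, 0, -1.25, 1, 0, 0])
print(np.max(np.abs(pleiades_loops(u0, 0.0) - pleiades_vectorized(u0, 0.0))))

t = np.linspace(0, 3.0, 31)
sol = odeint(pleiades_vectorized, u0, t, rtol=1e-8, atol=1e-8)
```

For only 7 bodies the broadcast version mainly removes Python-level loop overhead; whether it beats the Numba-compiled loop would need to be measured.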
import torch
from torchdiffeq import odeint_adjoint as odeint
import timeit
@torch.jit.script
class PleiadesODE(torch.nn.Module):
def __init__(self):
super(PleiadesODE, self).__init__()
self.a = nn.Parameter(torch.as_tensor([1.0]))
self.b = nn.Parameter(torch.as_tensor([1.0]))
self.c = nn.Parameter(torch.as_tensor([1.0]))
self.d = nn.Parameter(torch.as_tensor([1.0]))
def forward(self, t, u):
x = u[0:6] # x
y = u[7:13] # y
v = u[14:20] # x′
w = u[21:27] # y′
du = torch.zeros(28)
du[0:6] = self.a[0] * v
du[7:13]= self.b[0] * w
for i in range(6):
for j in range(6):
if i != j:
r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
du[13+i] += self.c[0] * j*(x[j] - x[i])/r
du[20+i] += self.d[0] * j*(y[j] - y[i])/r
return du
u0 = torch.tensor([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = torch.linspace(0, 3.0, 31)
tmp = PleiadesODE()
odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
with torch.no_grad():
odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=10)/10 # 0.40516944999999394 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
optimizer.zero_grad()
out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
loss = torch.mean(out)
loss.backward()
time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 6.3303714999992735 seconds
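# Reaction-diffusion equation (N=128): DifferentialEquations.jl benchmark
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics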
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 128
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 299.512 ms (3263 allocations: 251.83 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 830.273 ms (83 allocations: 10.51 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 907.863 ms (80 allocations: 9.38 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
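Why explicit solvers need many steps here can be seen from the spectrum of the diffusion operator. The NumPy sketch below (small N and dense matrices, chosen only to keep it cheap) builds the scripts' My and Mx, checks the Kronecker-product form of D*(My*A + A*Mx), and computes spectral radii. For this boundary treatment the 1D matrix has spectral radius 4 for any N, so the 2D diffusion term has spectral radius 8D = 800. Explicit Runge-Kutta methods such as DP5 and Tsit5 have stability regions that reach only a bounded distance along the negative real axis, so their stable step size shrinks roughly like 1/800, while stabilized explicit methods such as ROCK4 extend that interval by adding stages; ROCK4's `eigen_est` option, used in the gradient code above, is intended for supplying a spectral-radius bound of this kind.

```python
import numpy as np


def reflecting_second_difference(n: int) -> np.ndarray:
    """1D second-difference matrix with the 2.0 boundary entries used for My in the scripts."""
    M = (np.diag(np.ones(n - 1), -1) + np.diag(-2.0 * np.ones(n))
         + np.diag(np.ones(n - 1), 1))
    M[0, 1] = 2.0
    M[n - 1, n - 2] = 2.0
    return M


D = 100.0  # diffusion coefficient used in the reaction-diffusion scripts
N = 16     # small grid so the dense eigenvalue computation stays cheap
My = reflecting_second_difference(N)
Mx = My.T  # same as the scripts' Mx, whose 2.0 entries sit at [1, 0] and [N-2, N-1]

# Operator acting on A.flatten() (row-major order): D * (My @ A + A @ Mx)
L2d = D * (np.kron(My, np.eye(N)) + np.kron(np.eye(N), Mx.T))

A = np.random.default_rng(0).random((N, N))
assert np.allclose(L2d @ A.flatten(), (D * (My @ A + A @ Mx)).flatten())

rho_1d = np.max(np.abs(np.linalg.eigvals(My)))
rho_2d = np.max(np.abs(np.linalg.eigvals(L2d)))
print(f"spectral radius of My: {rho_1d:.6f}")
print(f"spectral radius of the 2D diffusion operator: {rho_2d:.3f}")
```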
# Reaction-diffusion equation (N=128): NumPy setup for the torchdiffeq model
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 128
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
# Reaction-diffusion equation (N=128): torchdiffeq benchmark (TorchScript)
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        # Trainable source term α₁, initialised from the NumPy array a1
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        # MyA = tMy@A, written with slices instead of a dense matrix product
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        # AMx = A@tMx, written with slices instead of a dense matrix product
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

# Script an instance of the module (the nn.Module class itself cannot be scripted)
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)  # warm-up call so JIT time is excluded

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 9.293160799999896 seconds
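# Reaction-diffusion equation (N=16): DifferentialEquations.jl benchmark
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics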
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 16
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,5.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 3.300 ms (2725 allocations: 4.95 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 8.742 ms (152 allocations: 731.91 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 9.135 ms (152 allocations: 712.92 KiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 1.228 s (304531 allocations: 1.37 GiB)
function fz(p)
mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=ForwardSensitivity()))
end
@btime Zygote.gradient(fz,vec(α₁)) # 3.515 s (3413745 allocations: 1.10 GiB)
function fz2(p)
mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=BacksolveAdjoint()))
end
@btime Zygote.gradient(fz2,vec(α₁)) # Diverges!
# Reaction-diffusion equation (N=16): SciPy benchmark and NumPy setup
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 16
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))

tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)  # untimed first call
import timeit

def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

timeit.Timer(time_func).timeit(number=1) # 2.2156629000000976 seconds
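The note that lsoda switches to BDF can be checked directly: `odeint(..., full_output=True)` also returns an info dictionary whose `mused` entry records the method used, 1 for Adams (non-stiff) and 2 for BDF (stiff). The sketch below uses a smaller, hypothetical 1D diffusion problem with the same D = 100 and the same reflecting-boundary matrix, not the benchmark's 2D system; whether and when the switch happens depends on the problem and tolerances, so the code only reports what it finds.

```python
import numpy as np
from scipy.integrate import odeint

# Hypothetical stand-in problem: 1D diffusion with D = 100, a source on the
# last fifth of the grid, and linear decay (loosely mirroring the A equation)
N = 16
D = 100.0
M = np.diag(np.ones(N - 1), -1) + np.diag(-2.0 * np.ones(N)) + np.diag(np.ones(N - 1), 1)
M[0, 1] = 2.0
M[N - 1, N - 2] = 2.0
source = np.where(np.arange(N) >= 4 * N / 5, 1.0, 0.0)


def rhs(u, t):
    return D * (M @ u) + source - u


t = np.linspace(0, 10, 101)
sol, info = odeint(rhs, np.zeros(N), t, rtol=1e-8, atol=1e-8, mxstep=5000,
                   full_output=True)
methods = info["mused"]  # 1 = Adams (non-stiff), 2 = BDF (stiff)
print("Adams intervals:", int(np.sum(methods == 1)),
      "BDF intervals:", int(np.sum(methods == 2)))
```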
# Reaction-diffusion equation (N=16): torchdiffeq benchmark (TorchScript)
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        # Trainable source term α₁, initialised from the NumPy array a1
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        # MyA = tMy@A, written with slices instead of a dense matrix product
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        # AMx = A@tMx, written with slices instead of a dense matrix product
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

# Script an instance of the module (the nn.Module class itself cannot be scripted)
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)  # warm-up call so JIT time is excluded

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 2.777176000000054 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)

def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()  # gradient via torchdiffeq's adjoint (backsolve) method

time_grad() # AssertionError: underflow in dt 8.289381047241916e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.289381047241916e-16
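The commented-out "all zero!" checks in the reaction-diffusion scripts compare the slicing-based boundary stencils against the matrix products My@A and A@Mx. A runnable NumPy version of that check, as a sketch: the helper names are illustrative, while the matrices and slices mirror the scripts.

```python
import numpy as np


def second_difference_matrices(n):
    """Mx and My exactly as built in the Python scripts (0-based indices)."""
    base = np.diag(np.ones(n - 1), -1) + np.diag(-2.0 * np.ones(n)) + np.diag(np.ones(n - 1), 1)
    Mx = base.copy()
    My = base.copy()
    Mx[1, 0] = 2.0
    Mx[n - 2, n - 1] = 2.0
    My[0, 1] = 2.0
    My[n - 1, n - 2] = 2.0
    return Mx, My


def sliced_terms(A):
    """My@A and A@Mx computed with the slicing stencils from the forward() methods."""
    n = A.shape[0]
    top = -2 * A[0, :] + 2 * A[1, :]
    bottom = 2 * A[n - 2, :] - 2 * A[n - 1, :]
    MyA = np.vstack((top, A[0:n - 2, :] - 2 * A[1:n - 1, :] + A[2:n, :], bottom))
    left = (-2 * A[:, 0] + 2 * A[:, 1]).reshape(n, 1)
    right = (2 * A[:, n - 2] - 2 * A[:, n - 1]).reshape(n, 1)
    AMx = np.hstack((left, A[:, 0:n - 2] - 2 * A[:, 1:n - 1] + A[:, 2:n], right))
    return MyA, AMx


rng = np.random.default_rng(0)
for n in (3, 8, 16):
    A = rng.random((n, n))
    Mx, My = second_difference_matrices(n)
    MyA, AMx = sliced_terms(A)
    print(n, np.max(np.abs(MyA - My @ A)), np.max(np.abs(AMx - A @ Mx)))
```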
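# Reaction-diffusion equation (N=256): DifferentialEquations.jl benchmark
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics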
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 256
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 1.586 s (3239 allocations: 988.71 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 5.285 s (83 allocations: 42.01 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 6.195 s (80 allocations: 37.51 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 256
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torch import nn, optim
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
# torch.jit.script cannot compile an nn.Module subclass directly ("pass an instance instead"),
# so script an instance; forward reads self.ta1 so the parameter actually enters the RHS.
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 37.48328430000038 seconds
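The commented `# all zero!` checks in the PyTorch script compare the slicing-based boundary rows against multiplication by `My` and `Mx`. The same comparison is sketched below in dependency-free Python. It assumes plain nested lists stand in for the NumPy/PyTorch arrays, and all helper names are hypothetical. It shows why the two forms agree: the 2.0 entries at `My[0,1]` and `My[N-1,N-2]` reproduce the `top` and `bottom` rows exactly.

```python
import random

def neumann_matrix(n):
    """Tridiagonal [1, -2, 1] with ghost-point entries M[0][1] = M[n-1][n-2] = 2 (like My)."""
    M = [[0.0] * n for _ in range(n)]
    for k in range(n):
        M[k][k] = -2.0
        if k > 0:
            M[k][k - 1] = 1.0
        if k < n - 1:
            M[k][k + 1] = 1.0
    M[0][1] = 2.0
    M[n - 1][n - 2] = 2.0
    return M

def matmul(P, Q):
    """Dense matrix product: O(n^3) work for n x n inputs."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]

def stencil_rows(A):
    """Slicing form used in the scripts: top row, interior rows, bottom row (O(n^2) work)."""
    n = len(A)
    top = [-2 * A[0][j] + 2 * A[1][j] for j in range(n)]
    bottom = [2 * A[n - 2][j] - 2 * A[n - 1][j] for j in range(n)]
    interior = [[A[i - 1][j] - 2 * A[i][j] + A[i + 1][j] for j in range(n)]
                for i in range(1, n - 1)]
    return [top] + interior + [bottom]

random.seed(1)
N = 7
A = [[random.random() for _ in range(N)] for _ in range(N)]
My = neumann_matrix(N)
diff = max(abs(x - y) for r1, r2 in zip(matmul(My, A), stencil_rows(A)) for x, y in zip(r1, r2))
print(diff < 1e-12)  # True
```

In the Python setup, `Mx` is the transpose of `My`, so `A @ Mx` is the column-wise version of the same stencil. With the dense `np.diag`-built matrices, each product costs O(N^3), while the slicing form costs O(N^2).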
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
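using LoopVectorization
using ForwardDiff, Zygote, DiffEqSensitivity, Statistics # gradients, adjoint sensealg, and mean() used below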
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 32
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 14.397 ms (3311 allocations: 16.50 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 37.298 ms (83 allocations: 679.31 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 38.608 ms (80 allocations: 606.47 KiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 18.495 s (1124288 allocations: 13.90 GiB)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁)) # 27.777 s (774420 allocations: 447.04 MiB)
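`ForwardDiff.gradient(f, α₁)` differentiates with respect to all N² = 1024 entries of `α₁`. Forward-mode AD propagates one tangent direction per parameter. ForwardDiff groups several directions into a chunk per pass, but the total work still grows with the parameter count. An adjoint method, by contrast, obtains the whole gradient from one reverse solve.

The sketch below illustrates the forward-mode side with a hand-rolled dual number. The `Dual` class is hypothetical and is not ForwardDiff's implementation. The toy model is also an assumption: a decoupled ODE u' = p - u, integrated with explicit Euler and chosen so that the exact gradient is known in closed form.

```python
class Dual:
    """Minimal dual number val + der*eps with eps**2 = 0 (forward-mode AD)."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        o = self._lift(other)
        return Dual(self.val + o.val, self.der + o.der)
    __radd__ = __add__

    def __sub__(self, other):
        o = self._lift(other)
        return Dual(self.val - o.val, self.der - o.der)

    def __mul__(self, other):
        o = self._lift(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)
    __rmul__ = __mul__

def loss(p, steps=100, T=1.0):
    """Mean of u(T) for u_i' = p_i - u_i, u_i(0) = 0, explicit Euler (toy stand-in for the PDE)."""
    h = T / steps
    u = [0.0 * pk for pk in p]
    for _ in range(steps):
        u = [ui + h * (pk - ui) for ui, pk in zip(u, p)]
    total = 0.0
    for ui in u:
        total = total + ui
    return total * (1.0 / len(p))

def forward_gradient(f, p):
    """One full forward evaluation per parameter, seeding derivative 1 in direction k."""
    grad, passes = [], 0
    for k in range(len(p)):
        seeded = [Dual(v, 1.0 if i == k else 0.0) for i, v in enumerate(p)]
        grad.append(f(seeded).der)
        passes += 1
    return grad, passes

p = [1.0, 2.0, 0.5, 3.0]
grad, passes = forward_gradient(loss, p)
exact = (1 - (1 - 0.01) ** 100) / len(p)  # d mean(u(T)) / d p_k for Euler with h = 0.01
print(passes, all(abs(g - exact) < 1e-12 for g in grad))  # 4 True
```

With 1024 parameters, the same scheme needs 1024 directional derivatives, each carried through the entire ODE solve.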
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 32
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 6.712808100000984 seconds
import torch
from torch import nn, optim
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
# torch.jit.script cannot compile an nn.Module subclass directly ("pass an instance instead"),
# so script an instance; forward reads self.ta1 so the parameter actually enters the RHS.
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 3.120565799999895 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad() # AssertionError: underflow in dt 8.230862228687043e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.230862228687043e-16
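The `AssertionError: underflow in dt` above is raised when torchdiffeq's adaptive step-size control gives up during the reverse-time (backsolve) pass. One mechanism consistent with this is sketched below as an illustration, not a diagnosis of torchdiffeq's internals. A dissipative ODE run backward in time amplifies any perturbation by the same factor that the forward solve damped it. The diffusion term in this model has decay rates up to 8D = 800 on the unscaled grid.

The toy sketch assumes a single decaying mode u' = -k u with k = 50 and a fixed-step RK4 integrator; both are hypothetical choices made to keep the numbers finite. It integrates forward, perturbs the final state by a roundoff-sized 1e-16, and integrates back.

```python
def rk4(f, y, t0, t1, steps):
    """Classic fixed-step RK4 from t0 to t1 (t1 < t0 integrates backward in time)."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

k = 50.0

def decay(t, y):
    """Single dissipative mode u' = -k u."""
    return -k * y

yT = rk4(decay, 1.0, 0.0, 1.0, 1000)                      # forward: about exp(-50) ~ 2e-22
back_exact = rk4(decay, yT, 1.0, 0.0, 1000)               # backward from the computed state
back_perturbed = rk4(decay, yT + 1e-16, 1.0, 0.0, 1000)   # same, with a 1e-16 nudge at t = T

print(f"{yT:.3e} {back_exact:.6f} {back_perturbed:.3e}")
```

Backward integration multiplies the nudge by about e^{kT} ≈ 5e21, so the "recovered" initial value is off by roughly 5e5 instead of being 1. With rates near 800 over t in [0, 10], the amplification is far beyond floating-point range. An error-controlled reverse solve then keeps shrinking the step, which is one plausible route to the underflow.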
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 512
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
function f!(du,u,p,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 10.261 s (3406 allocations: 4.34 GiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 27.946 s (266 allocations: 708.02 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 29.294 s (263 allocations: 690.02 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 512
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torch import nn, optim
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
# torch.jit.script cannot compile an nn.Module subclass directly ("pass an instance instead"),
# so script an instance; forward reads self.ta1 so the parameter actually enters the RHS.
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 172.58722800000032 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
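using LoopVectorization
using ForwardDiff, Zygote, DiffEqSensitivity, Statistics # gradients, adjoint sensealg, and mean() used below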
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 64
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 64.192 ms (3287 allocations: 64.24 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 180.134 ms (83 allocations: 2.63 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 192.216 ms (80 allocations: 2.35 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
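The `EIGEN_EST` lines hand ROCK4 a precomputed magnitude of the largest eigenvalue of `My`, so the solver does not have to estimate it. For this Neumann matrix the value is known exactly. Gershgorin's theorem bounds every eigenvalue by |λ + 2| ≤ 2, and the alternating vector (1, -1, 1, ...) attains λ = -4. Hence `maximum(abs, eigvals(Matrix(My)))` is 4 for every N ≥ 2.

The stdlib-only Python check below verifies both facts for a few sizes; its helper names are hypothetical. Eigenvalues of the full 2D diffusion term D*(My*A + A*Mx) are sums of one eigenvalue from each factor, so its spectral radius is 8D.

```python
def neumann_matrix(n):
    """Tridiagonal [1, -2, 1] with ghost-point entries M[0][1] = M[n-1][n-2] = 2 (like My)."""
    M = [[0.0] * n for _ in range(n)]
    for k in range(n):
        M[k][k] = -2.0
        if k > 0:
            M[k][k - 1] = 1.0
        if k < n - 1:
            M[k][k + 1] = 1.0
    M[0][1] = 2.0
    M[n - 1][n - 2] = 2.0
    return M

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def gershgorin_bound(M):
    """Max over rows of |center| + radius: an upper bound on every |eigenvalue|."""
    return max(abs(row[i]) + sum(abs(x) for j, x in enumerate(row) if j != i)
               for i, row in enumerate(M))

for n in (2, 3, 8, 32):
    My = neumann_matrix(n)
    v = [(-1.0) ** i for i in range(n)]                       # alternating vector
    is_eig = all(abs(a + 4.0 * b) < 1e-12 for a, b in zip(matvec(My, v), v))  # My v = -4 v
    print(n, gershgorin_bound(My), is_eig)                    # each line: n 4.0 True
```

Because the bound and the attained eigenvalue coincide, the estimate is exact rather than a heuristic.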
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 64
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 173.78868460000012 seconds
import torch
from torch import nn, optim
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
# torch.jit.script cannot compile an nn.Module subclass directly ("pass an instance instead"),
# so script an instance; forward reads self.ta1 so the parameter actually enters the RHS.
tmp = torch.jit.script(ReactionDiffusionODE())
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 5.241670299998077 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad() # AssertionError: underflow in dt 7.349905693271135e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 7.349905693271135e-16
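The timings recorded in the benchmark comments can be turned into relative factors of the kind reported in the summary, where each time is divided by the fastest method (ROCK4 here). The small Python sketch below uses only numbers copied from those comments. It compares DP5 from DifferentialEquations.jl with the no-grad torchdiffeq forward pass, and it relies on each problem having 3N² ODEs.

```python
# Forward-pass timings in seconds, copied from the benchmark comments above.
timings = {
    32:  {"ROCK4": 14.397e-3, "DP5": 38.608e-3,  "torchdiffeq": 3.120565799999895},
    64:  {"ROCK4": 64.192e-3, "DP5": 192.216e-3, "torchdiffeq": 5.241670299998077},
    256: {"ROCK4": 1.586,     "DP5": 6.195,      "torchdiffeq": 37.48328430000038},
    512: {"ROCK4": 10.261,    "DP5": 29.294,     "torchdiffeq": 172.58722800000032},
}

def relative(row):
    """Divide every timing by the fastest one in the row (1.0x = fastest)."""
    best = min(row.values())
    return {name: t / best for name, t in row.items()}

rel = {}
for n, row in timings.items():
    rel[n] = relative(row)
    cells = ", ".join(f"{name} {r:.1f}x" for name, r in rel[n].items())
    print(f"{3 * n * n:>7,} ODEs: {cells}")  # 3 species on an N x N grid
```

The DP5 column shows the cost of a non-stabilized explicit method on this mildly stiff diffusion problem. The torchdiffeq column combines that algorithmic gap with implementation overhead.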
@toollu

toollu commented Aug 10, 2020

I'm relatively new to both Python and Julia, so bear with me.
First of all, thanks for your amazing work on the DiffEqFlux library; it looks super promising for my purpose.
However, when I run lorenz.py as it is, I get the following trace:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/user/SRC/scratch.py", line 39, in <module>
    class LorenzODE(torch.nn.Module):
  File "/opt/anaconda3/envs/DSP/lib/python3.7/site-packages/torch/jit/__init__.py", line 1525, in script
    " pass an instance instead".format(obj))
RuntimeError: Type '<class '__main__.LorenzODE'>' cannot be compiled since it inherits from nn.Module, pass an instance instead.

Without the @torch.jit.script it runs fine however. Is it the code or is it my environment?

@ChrisRackauckas
Author

The torch jit has some limitations on how you run the JIT'd code IIRC. I think the easiest one to run was https://gist.github.com/ChrisRackauckas/cc6ac746e2dfd285c28e0584a2bfd320/revisions#diff-0cb4ce6da1b7bafa3e58087c7e9759a6 , so I might just revive that version.

@chacowingnut

Howdy, I noticed that you're calling the old odeint, which is deprecated. I tried reproducing your Lorenz Python results using scipy's newer solve_ivp, and achieved a dramatic speed boost. On my machine the forward pass timing improved from 45ish ms to ~1.1 ms. Using DOP853 via solve_ivp's method kwarg I get timings just a hair above half a ms.

import numpy as np
from scipy.integrate import solve_ivp
import timeit

def f(t, u, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)
sol = solve_ivp(f, t, u0, args=(10.0,28.0,8/3))

def time_func():
    solve_ivp(f, t, u0, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8)

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.0011336688511073589 seconds

@ChrisRackauckas
Author

The reason the timings use odeint is that odeint is known to be a lot faster than solve_ivp. So then why is solve_ivp faster in your example? Well, because there's an error in the code 😅. If you check the code you wrote:

import numpy as np
from scipy.integrate import solve_ivp
import timeit

def f(t, u, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)

sol = solve_ivp(f, t, u0, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8, method="RK45")
len(sol.t) # 18

You see that it doesn't act quite the way that you would expect. You asked for the values at t, but it gave you 18 values instead of 1001. You need to pass the interval tspan as the second argument and add the keyword argument t_eval in order to control the output times in solve_ivp:

sol = solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="RK45")
len(sol.t) # 1001

and what's the timing with the corrected arguments?

def time_func():
    solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="RK45")

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.8685612489999994 seconds

DOP853 then is similar:

def time_func():
    solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="DOP853")

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.5002750910000009 seconds

This particular test is stress testing more of the implementation than the algorithm itself.

@chacowingnut

Haha I see. Many thanks for the response and for your contributions in this area!

@ChrisRackauckas
Author

No problem, and thanks for your contribution. Noting the solve_ivp time is definitely important.

@roflmaostc

Were those tests ever performed on GPUs too?

@ChrisRackauckas
Author

These were all CPU. We will be putting out some GPU benchmarks with the new GPU infrastructure next month. The results are pretty much what you'd expect though, with Julia matching the C++ baseline, JAX about 10x slower, and torch pretty far back, because vmap-style parallelism isn't an efficient way to handle ODEs. But details should come out next month.
