torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

Torchdiffeq vs DifferentialEquations.jl (/ DiffEqFlux.jl) Neural ODE Compatible Solver Benchmarks

Only non-stiff ODE solvers are tested, since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and in model-informed drug development (MIDD) studies (quantitative systems pharmacology), in order to capture performance in realistic scenarios.

Summary

Below are the timings relative to the fastest method (lower is better). For approximately 1 million ODEs or fewer, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl on every tested problem, and in many cases substantially slower than that. Note, however, that the relative performance of torchdiffeq improves as the number of ODEs increases.

Additionally, torchdiffeq either exhibited slower gradient calculations or its gradient calculation diverged. For more on why the torchdiffeq calculation diverges, see this manuscript, along with many others detailing the stability of backsolve adjoint methods. Note that when sensealg=BacksolveAdjoint() is used in DifferentialEquations.jl on these problems it similarly diverges, indicating that the issue lies with the algorithm itself (which is why the DiffEq documentation does not recommend this algorithm on such problems).
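The mechanism behind that kind of divergence can be illustrated on a toy problem. The sketch below is not taken from the benchmark scripts: it assumes a scalar decaying ODE u' = -λu, a hand-written RK4 stepper, and hypothetical helper names (`rk4_linear`, `backsolve_error`). A backsolve adjoint reconstructs the forward trajectory by integrating the ODE backwards in time, which turns decay into growth, so an absolute error of roughly the solver tolerance at the final time can be amplified by about e^(λT).

```python
def rk4_linear(u, a, h, n):
    """Advance u' = a*u by n classical RK4 steps of size h (h may be negative)."""
    z = a * h
    # For a linear scalar ODE, one RK4 step multiplies u by this polynomial in z.
    growth = 1.0 + z + z**2 / 2.0 + z**3 / 6.0 + z**4 / 24.0
    for _ in range(n):
        u *= growth
    return u


def backsolve_error(lam, T, n, delta, u0=1.0):
    """Solve forward to T, perturb the end state by delta, solve back to 0.

    Returns the absolute error in the reconstructed initial condition.
    """
    h = T / n
    u_T = rk4_linear(u0, -lam, h, n)               # forward pass: decay
    u_rec = rk4_linear(u_T + delta, -lam, -h, n)   # reverse pass: growth
    return abs(u_rec - u0)


# lam, T, n and delta are illustrative choices; delta mimics abstol=1e-8.
clean = backsolve_error(25.0, 1.0, 1000, 0.0)
perturbed = backsolve_error(25.0, 1.0, 1000, 1e-8)
print(f"error without perturbation: {clean:.3e}")
print(f"error with 1e-8 perturbation at T: {perturbed:.3e}")
```

With no perturbation the reverse solve recovers the initial condition closely, while a terminal error the size of the tolerance grows by many orders of magnitude on the way back.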

Relative Performance of Forward Pass (Lower is Better)

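| Number of ODEs | 3 | 28 | 768 | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
| DifferentialEquations.jl dopri5 | 1.0x | 1.6x | 2.8x | 2.7x | 3.0x | 3.0x | 3.9x | 2.8x |
| torchdiffeq dopri5 | 4,900x | 190x | 840x | 220x | 82x | 31x | 24x | 17x |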
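The ODE counts in the "Number of ODEs" row follow from the problem sizes used in the scripts: the reaction-diffusion state holds three species (A, B, C) on an N×N grid, and its parameter α₁ is an N×N array. A small sketch of that arithmetic, with hypothetical helper names:

```python
def num_odes(n):
    """Three species (A, B, C), each discretized on an n-by-n grid."""
    return 3 * n * n


def num_params(n):
    """The source term alpha_1 is an n-by-n array of parameters."""
    return n * n


grid_sizes = [16, 32, 64, 128, 256, 512]
for n in grid_sizes:
    print(f"N={n}: {num_odes(n):,} ODEs, {num_params(n):,} parameters")
```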

Relative Performance of Gradient Pass (Lower is Better)

| Number of Parameters | 3 | 4 | 256 | 1,024 |
| --- | --- | --- | --- | --- |
| DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x |
| torchdiffeq dopri5 | 12,000x | 1,200x | ---- | ---- |

---- means the gradient calculation diverged. To be clear, torchdiffeq did not successfully compute the gradient on any of the PDE experiments. All returned an error due to dt underflow, leading to the experiment being halted.
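One way to read a dt underflow: once an adaptive solver has shrunk its step far enough, adding the step to the current time no longer changes the time in floating point, so the integration cannot advance. The sketch below uses the dt value from the error message recorded in the scripts and a hypothetical helper `step_underflows`; it illustrates the floating-point effect only and makes no claim about how torchdiffeq implements its check.

```python
def step_underflows(t, dt):
    """True if adding dt to t leaves t unchanged in double precision."""
    return t + dt == t


dt = 8.289381047241916e-16  # value reported in the recorded AssertionError

# Near t = 10 the spacing between adjacent doubles is about 1.8e-15,
# so a step this small is rounded away entirely.
print(step_underflows(10.0, dt))

# Near t = 1 the spacing is about 2.2e-16, so the same step still registers.
print(step_underflows(1.0, dt))
```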

Benchmark Details

Lorenz Equation (3 ODEs)

Absolute Timings

Forward Pass

  • DifferentialEquations.jl: 1.742 ms
  • SciPy+Numba: 30.8 ms
  • SciPy: 50.2 ms
  • SciPy solve_ivp: 869 ms
  • torchscript torchdiffeq (dopri5): 8.60 seconds

Gradient

  • DifferentialEquations.jl: 4.281 ms
  • torchscript torchdiffeq: 51.9 seconds

Relative Timings (Lower is better)

Forward Pass

  • DifferentialEquations.jl: 1x
  • SciPy+Numba: 18x slower
  • SciPy: 29x slower
  • torchscript torchdiffeq: 4,900x slower

Gradient

  • DifferentialEquations.jl: 1x
  • torchscript torchdiffeq: 12,000x slower
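The relative figures are the absolute timings divided by the fastest one. A minimal sketch of that calculation using the Lorenz forward-pass numbers listed above; the helper name `relative_slowdowns` and the rounding to two significant figures are assumptions made for illustration:

```python
def relative_slowdowns(timings_ms):
    """Divide each timing by the fastest and round to two significant figures."""
    fastest = min(timings_ms.values())
    return {name: float(f"{t / fastest:.2g}") for name, t in timings_ms.items()}


# Absolute forward-pass timings for the Lorenz problem, in milliseconds.
lorenz_forward_ms = {
    "DifferentialEquations.jl": 1.742,
    "SciPy+Numba": 30.8,
    "SciPy": 50.2,
    "torchscript torchdiffeq (dopri5)": 8600.0,
}

for name, factor in relative_slowdowns(lorenz_forward_ms).items():
    print(f"{name}: {factor:g}x")
```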

Pleiades Equation (28 ODEs)

Absolute Timings

Forward Pass

  • DifferentialEquations.jl: 2.118 ms
  • DifferentialEquations.jl DP5: 3.407 ms
  • SciPy+Numba: 2.6 ms
  • SciPy: 13.2 ms
  • torchscript torchdiffeq (dopri5): 405 ms

Gradient

  • DifferentialEquations.jl: 5.501 ms
  • torchscript torchdiffeq: 6.33 seconds

Relative Timings (Lower is better)

Forward Pass

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 1.6x slower
  • SciPy+Numba: 1.2x slower
  • SciPy: 6.2x slower
  • torchscript torchdiffeq (dopri5): 190x slower

Gradient

  • DifferentialEquations.jl: 1x
  • torchscript torchdiffeq: 1,200x slower

Non-stiff Reaction Diffusion Equation (N=16) (768 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 3.300 ms
  • DifferentialEquations.jl DP5: 9.135 ms
  • SciPy: 2.2 seconds
  • SciPy+Numba: Failed to compile (numpy.ndarray)
  • torchscript torchdiffeq (dopri5): 2.78 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.8x slower
  • SciPy: 670x slower
  • torchscript torchdiffeq (dopri5): 840x slower

Non-stiff Reaction Diffusion Equation (N=32) (3072 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 14.397 ms
  • DifferentialEquations.jl DP5: 38.608 ms
  • SciPy: 6.71 seconds
  • torchscript torchdiffeq (dopri5): 3.12 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.7x slower
  • SciPy: 460x slower
  • torchscript torchdiffeq (dopri5): 220x slower

Non-stiff Reaction Diffusion Equation (N=64) (12,288 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 64.192 ms
  • DifferentialEquations.jl DP5: 192.216 ms
  • SciPy: 174 seconds
  • torchscript torchdiffeq (dopri5): 5.24 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.0x slower
  • SciPy: 2,700x slower
  • torchscript torchdiffeq (dopri5): 82x slower

Non-stiff Reaction Diffusion Equation (N=128) (49,152 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 299.512 ms
  • DifferentialEquations.jl DP5: 907.863 ms
  • torchscript torchdiffeq (dopri5): 9.29 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.0x slower
  • torchscript torchdiffeq (dopri5): 31x slower

Non-stiff Reaction Diffusion Equation (N=256) (196,608 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 1.586 seconds
  • DifferentialEquations.jl DP5: 6.195 seconds
  • torchscript torchdiffeq (dopri5): 37.5 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 3.9x slower
  • torchscript torchdiffeq (dopri5): 24x slower

Non-stiff Reaction Diffusion Equation (N=512) (786,432 ODEs)

Absolute Timings

  • DifferentialEquations.jl: 10.3 seconds
  • DifferentialEquations.jl DP5: 29.3 seconds
  • torchscript torchdiffeq (dopri5): 172.59 seconds

Relative Timings (Lower is better)

  • DifferentialEquations.jl: 1x
  • DifferentialEquations.jl DP5: 2.8x slower
  • torchscript torchdiffeq (dopri5): 17x slower

Notes

The torchscript versions are kept as separate scripts to allow the JIT compilation to occur, and each function is called once before timing to exclude JIT time, as suggested by the PyTorch documentation. Python results were divided by the number of runs passed to timeit. Note that the increase in the SciPy timings on the reaction-diffusion problem is due to LSODA switching to its BDF (stiff) method and using the Jacobian: with this diffusion coefficient that is unnecessary and leads to a large slowdown.
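The warm-up-then-average pattern described here can be sketched with the standard library alone. The helper name `mean_runtime` and the stand-in workload are hypothetical; in the benchmark scripts the timed function wraps an `odeint` call instead.

```python
import timeit


def mean_runtime(func, number):
    """Call func once untimed (warm-up), then return the mean time per run."""
    func()  # warm-up so one-time costs such as JIT compilation are excluded
    total = timeit.Timer(func).timeit(number=number)
    return total / number  # scale by the number of timed runs


calls = []


def workload():
    # Stand-in for the solver call being benchmarked.
    calls.append(sum(i * i for i in range(1000)))


seconds = mean_runtime(workload, number=100)
print(f"mean runtime: {seconds:.6f} seconds over 100 runs")
```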

using OrdinaryDiffEq, StaticArrays, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote, Statistics
function lorenz_static(u,p,t)
@inbounds begin
dx = p[1]*(u[2]-u[1])
dy = u[1]*(p[2]-u[3]) - u[2]
dz = u[1]*u[2] - p[3]*u[3]
end
@SVector [dx,dy,dz]
end
u0 = @SVector [1.0,0.0,0.0]
p = @SVector [10.0,28.0,8/3]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz_static,u0,tspan,p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.018 ms (31 allocations: 59.25 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 1.742 ms (35 allocations: 59.17 KiB)
function lorenz!(du,u,p,t)
du[1] = p[1]*(u[2]-u[1])
du[2] = u[1]*(p[2]-u[3]) - u[2]
du[3] = u[1]*u[2] - p[3]*u[3]
end
prob = ODEProblem(lorenz!,Array(u0),tspan,Array(p))
function fz(p)
mean(solve(prob,DP5(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,Array(p)) # 4.281 ms (8308 allocations: 1.18 MiB)
import numpy as np
from scipy.integrate import odeint
import timeit
import numba
def f(u, t, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)
sol = odeint(f, u0, t, args=(10.0,28.0,8/3))
def time_func():
    odeint(f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.05018504399999983 seconds
numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.03088667100000066 seconds
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn
import torch.optim as optim
import timeit
@torch.jit.script
class LorenzODE(torch.nn.Module):
    def __init__(self):
        super(LorenzODE, self).__init__()
        self.sigma = nn.Parameter(torch.as_tensor([10.0]))
        self.rho = nn.Parameter(torch.as_tensor([28.0]))
        self.beta = nn.Parameter(torch.as_tensor([2.66]))
    def forward(self, t, u):
        x, y, z = u[0],u[1],u[2]
        du1 = self.sigma[0] * (y - x)
        du2 = x * (self.rho[0] - z) - y
        du3 = x * y - self.beta[0] * z
        return torch.stack([du1, du2, du3])
u0 = torch.tensor([1.0,0.0,0.0])
t = torch.linspace(0, 100, 1001)
odeint(LorenzODE(), u0, t, rtol = 1e-8, atol=1e-8)
tmp = LorenzODE()
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=2)/2 # 8.595667100000014 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 51.85800955000013 seconds
using OrdinaryDiffEq, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote, Statistics
function f(du,u,p,t)
a = p[1]
b = p[2]
c = p[3]
d = p[4]
@inbounds begin
x = view(u,1:7) # x
y = view(u,8:14) # y
v = view(u,15:21) # x′
w = view(u,22:28) # y′
du[1:7] .= a.*v
du[8:14].= b.*w
for i in 15:28
du[i] = zero(u[1])
end
for i=1:7,j=1:7
if i != j
r = ((x[i]-x[j])^2 + (y[i] - y[j])^2)^(3/2)
du[14+i] += c*j*(x[j] - x[i])/r
du[21+i] += d*j*(y[j] - y[i])/r
end
end
end
end
p = ones(4)
prob = ODEProblem(f,[3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0],(0.0,3.0),p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 4.060 ms (76 allocations: 20.11 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 3.407 ms (76 allocations: 18.61 KiB)
@btime solve(prob,VCABM(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.118 ms (137 allocations: 36.59 KiB)
function fz(p)
mean(solve(prob,VCABM(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,p) # 5.501 ms (600 allocations: 265.66 KiB)
import numpy as np
from scipy.integrate import odeint
import timeit
import numba
def f(u,t):
    x = u[0:6] # x
    y = u[7:13] # y
    v = u[14:20] # x′
    w = u[21:27] # y′
    du = np.zeros(28)
    du[0:6] = v
    du[7:13]= w
    for i in range(6):
        for j in range(6):
            if i != j:
                r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                du[13+i] += j*(x[j] - x[i])/r
                du[20+i] += j*(y[j] - y[i])/r
    return du
u0 = np.array([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = np.linspace(0, 3.0, 31)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.013212921999997889 seconds
numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, rtol = 1e-8, atol=1e-8)
timeit.Timer(time_func).timeit(number=100)/100 # 0.0025926070000059556 seconds
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn        # needed for nn.Parameter below
import torch.optim as optim  # needed for optim.RMSprop below
import timeit
@torch.jit.script
class PleiadesODE(torch.nn.Module):
    def __init__(self):
        super(PleiadesODE, self).__init__()
        self.a = nn.Parameter(torch.as_tensor([1.0]))
        self.b = nn.Parameter(torch.as_tensor([1.0]))
        self.c = nn.Parameter(torch.as_tensor([1.0]))
        self.d = nn.Parameter(torch.as_tensor([1.0]))
    def forward(self, t, u):
        x = u[0:6] # x
        y = u[7:13] # y
        v = u[14:20] # x′
        w = u[21:27] # y′
        du = torch.zeros(28)
        du[0:6] = self.a[0] * v
        du[7:13]= self.b[0] * w
        for i in range(6):
            for j in range(6):
                if i != j:
                    r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                    du[13+i] += self.c[0] * j*(x[j] - x[i])/r
                    du[20+i] += self.d[0] * j*(y[j] - y[i])/r
        return du
u0 = torch.tensor([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = torch.linspace(0, 3.0, 31)
tmp = PleiadesODE()
odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=10)/10 # 0.40516944999999394 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 6.3303714999992735 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # used by the gradient benchmarks below
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 128
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 299.512 ms (3263 allocations: 251.83 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 830.273 ms (83 allocations: 10.51 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 907.863 ms (80 allocations: 9.38 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 128
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn # needed for nn.Parameter below
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 9.293160799999896 seconds
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # used by the gradient benchmarks below
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 16
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,5.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 3.300 ms (2725 allocations: 4.95 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 8.742 ms (152 allocations: 731.91 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 9.135 ms (152 allocations: 712.92 KiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 1.228 s (304531 allocations: 1.37 GiB)
function fz(p)
mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=ForwardSensitivity()))
end
@btime Zygote.gradient(fz,vec(α₁)) # 3.515 s (3413745 allocations: 1.10 GiB)
function fz2(p)
mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=BacksolveAdjoint()))
end
@btime Zygote.gradient(fz2,vec(α₁)) # Diverges!
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 16
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 2.2156629000000976 seconds
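import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn        # needed for nn.Parameter below
import torch.optim as optim  # needed for optim.RMSprop below
import timeit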
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 2.777176000000054 seconds
optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad() # AssertionError: underflow in dt 8.289381047241916e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.289381047241916e-16
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # used by the gradient benchmarks below
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 256
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 1.586 s (3239 allocations: 988.71 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 5.285 s (83 allocations: 42.01 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 6.195 s (80 allocations: 37.51 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 256
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@tMx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 37.48328430000038 seconds
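using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
using Statistics, ForwardDiff, Zygote # mean and the two gradient calls used below
using SciMLSensitivity # InterpolatingAdjoint, ReverseDiffVJP (package formerly named DiffEqSensitivity)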
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 32
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 14.397 ms (3311 allocations: 16.50 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 37.298 ms (83 allocations: 679.31 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 38.608 ms (80 allocations: 606.47 KiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 18.495 s (1124288 allocations: 13.90 GiB)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁)) # 27.777 s (774420 allocations: 447.04 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 32
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 6.712808100000984 seconds
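The commented-out `# all zero!` lines in these scripts check that the slicing stencil used inside the right-hand side matches the matrix form `My@A + A@Mx`. A minimal, self-contained sketch of that check is below. The function names and the small grid size are illustrative assumptions rather than part of the benchmark; the matrix entries and the slicing expressions are copied from the scripts.

```python
import numpy as np

def second_difference(n):
    # Plain tridiagonal [1, -2, 1] matrix; boundary entries are patched by the caller.
    return (np.diag(np.ones(n - 1), -1)
            + np.diag(-2.0 * np.ones(n), 0)
            + np.diag(np.ones(n - 1), 1))

def laplacian_matrix_form(A):
    # Matrix form My@A + A@Mx with the same boundary patches as the scripts.
    n = A.shape[0]
    Mx = second_difference(n)
    My = second_difference(n)
    Mx[1, 0] = 2.0
    Mx[n - 2, n - 1] = 2.0
    My[0, 1] = 2.0
    My[n - 1, n - 2] = 2.0
    return My @ A + A @ Mx

def laplacian_slice_form(A):
    # Slicing form used in the ODE right-hand side (no matrix multiply).
    n = A.shape[0]
    top = -2 * A[0, :] + 2 * A[1, :]
    bottom = 2 * A[n - 2, :] - 2 * A[n - 1, :]
    MyA = np.vstack((top, A[0:n - 2, :] - 2 * A[1:n - 1, :] + A[2:n, :], bottom))
    left = (-2 * A[:, 0] + 2 * A[:, 1]).reshape(n, 1)
    right = (2 * A[:, n - 2] - 2 * A[:, n - 1]).reshape(n, 1)
    AMx = np.hstack((left, A[:, 0:n - 2] - 2 * A[:, 1:n - 1] + A[:, 2:n], right))
    return MyA + AMx

rng = np.random.default_rng(0)
A = rng.random((8, 8))  # small illustrative grid, not the benchmark size
max_diff = np.max(np.abs(laplacian_matrix_form(A) - laplacian_slice_form(A)))
print(max_diff)  # expected to be at floating-point round-off level
```

The slicing form avoids two dense `N x N` matrix products per right-hand-side evaluation, which is presumably why the scripts use it in place of the commented-out `My@A` and `A@Mx` lines.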
import torch
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@tMx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 3.120565799999895 seconds
optimizer = torch.optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad() # AssertionError: underflow in dt 8.230862228687043e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.230862228687043e-16
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 512
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
function f!(du,u,p,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 10.261 s (3406 allocations: 4.34 GiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 27.946 s (266 allocations: 708.02 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 29.294 s (263 allocations: 690.02 MiB)
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 512
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
import torch
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@tMx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 172.58722800000032 seconds
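using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
using Statistics, ForwardDiff, Zygote # mean and the two gradient calls used below
using SciMLSensitivity # InterpolatingAdjoint, ReverseDiffVJP (package formerly named DiffEqSensitivity)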
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 64
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0
#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
mul!(MyA,My,A)
mul!(AMx,A,Mx)
@. DA = D*(MyA + AMx)
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#
function f!(_du,_u,_α₁,t)
u = reshape(_u,N,N,3)
du = reshape(_du,N,N,3)
A = @view u[:,:,1]
B = @view u[:,:,2]
C = @view u[:,:,3]
dA = @view du[:,:,1]
dB = @view du[:,:,2]
dC = @view du[:,:,3]
α₁ = reshape(_α₁,N,N)
@inbounds for j in 2:N-1, i in 2:N-1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = 1
dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for j in 2:N-1
i = N
dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = 1
dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds for i in 2:N-1
j = N
dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
@inbounds begin
i = 1; j = 1
dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = 1; j = N
dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = 1
dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
i = N; j = N
dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
end
end
#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#
u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 64.192 ms (3287 allocations: 64.24 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 180.134 ms (83 allocations: 2.63 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 192.216 ms (80 allocations: 2.35 MiB)
function f(p)
mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
import numpy as np
from scipy.integrate import odeint
a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 64
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)
Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0
u0 = np.ndarray.flatten(np.zeros((3,N,N)))
#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!
# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))
tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)
timeit.Timer(time_func).timeit(number=1) # 173.78868460000012 seconds
import torch
from torchdiffeq import odeint_adjoint as odeint
import timeit
tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)
#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@tMx # all zero!
@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = torch.nn.Parameter(ta1)
    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + self.ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))
tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
time_func()
timeit.Timer(time_func).timeit(number=1) # 5.241670299998077 seconds
optimizer = torch.optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()
time_grad() # AssertionError: underflow in dt 7.349905693271135e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 7.349905693271135e-16
@ChrisRackauckas
Author

No problem, and thanks for your contribution. Noting the solve_ivp time is definitely important.

@roflmaostc

Were those tests ever performed on GPUs too?

@ChrisRackauckas
Author

These were all CPU. We will be putting out some GPU benchmarks with the new GPU infrastructure next month. The results are pretty much what you'd expect though, with Julia matching the C++ baseline, JAX about 10x slower, and torch pretty far back, because vmap-style parallelism isn't an efficient way to handle ODEs. But details should come out next month.
