# ChrisRackauckas/diffeqflux_differentialequations_vs_torchdiffeq_results.md

Last active Jun 25, 2022
torchdiffeq (Python) vs DifferentialEquations.jl (Julia) ODE Benchmarks (Neural ODE Solvers)

# Torchdiffeq vs DifferentialEquations.jl (/ DiffEqFlux.jl) Neural ODE Compatible Solver Benchmarks

Only non-stiff ODE solvers are tested since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and model-informed drug development (MIDD) studies (quantitative systems pharmacology) in order to capture the performance on realistic scenarios.

## Summary

Below are the timings relative to the fastest method (lower is better). For problems with up to approximately 1 million ODEs, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl on every tested problem, and often far slower than that. Note, however, that the relative performance of torchdiffeq improves as the number of ODEs increases.

Additionally, torchdiffeq's gradient calculations were either slower or diverged. For more on why torchdiffeq's gradient calculation diverges, see this manuscript along with many others detailing the stability of backsolve adjoint methods. Note that when `sensealg=BacksolveAdjoint()` is used in DifferentialEquations.jl on these problems it similarly diverges, indicating that the issue lies with the algorithm itself (which is why the DiffEq documentation does not recommend this algorithm for such problems).
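To make the stability remark concrete, here is a minimal sketch of why rebuilding the forward state by integrating backward in time, as backsolve adjoint methods do, can blow up. It is not the torchdiffeq or DifferentialEquations.jl implementation: the test equation u' = -λu, the fixed-step RK4 helper `rk4_integrate`, and the values of `lam`, `T`, and the perturbation are all illustrative assumptions.

```python
import math

def rk4_integrate(f, u, t0, t1, steps):
    """Integrate u' = f(t, u) from t0 to t1 with fixed-step classical RK4.

    Passing t1 < t0 integrates backward in time (negative step size).
    """
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(t, u)
        k2 = f(t + h / 2, u + h / 2 * k1)
        k3 = f(t + h / 2, u + h / 2 * k2)
        k4 = f(t + h, u + h * k3)
        u = u + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return u

lam = 20.0                       # hypothetical decay rate
f = lambda t, u: -lam * u        # forward dynamics contract toward zero
T = 2.0
u0 = 1.0

# Forward solve: the solution decays to roughly exp(-40).
uT = rk4_integrate(f, u0, 0.0, T, 2000)

# Perturb the final state by a tiny amount, standing in for solver error,
# then integrate the same ODE backward to try to recover u0.
uT_perturbed = uT + 1e-12
u0_back = rk4_integrate(f, uT_perturbed, T, 0.0, 2000)
backward_error = abs(u0_back - u0)

print(f"u(T) = {uT:.3e}, recovered u(0) error = {backward_error:.3e}")
```

In this toy problem the forward solve damps errors, while the reverse solve amplifies the 1e-12 perturbation by about e^{λT}. Adjoint methods that store or interpolate the forward solution avoid this reverse solve of the state.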

### Relative Performance of Forward Pass (Lower is Better)

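| Number of ODEs | 3 | 28 | 768 | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
| DifferentialEquations.jl dopri5 | 1.0x | 1.6x | 2.8x | 2.7x | 3.0x | 3.0x | 3.9x | 2.8x |
| torchdiffeq dopri5 | 4,900x | 190x | 840x | 220x | 82x | 31x | 24x | 17x |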
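The ODE counts for the reaction-diffusion columns follow from the grid size: the scripts allocate three species on an N×N grid (`zeros(N,N,3)`), so each problem has 3N² states. A small sketch of that arithmetic, where the helper name `num_odes` is only illustrative:

```python
def num_odes(n, species=3):
    """Number of ODEs for `species` fields discretized on an n-by-n grid."""
    return species * n * n

grid_sizes = [16, 32, 64, 128, 256, 512]
counts = {n: num_odes(n) for n in grid_sizes}

for n, count in counts.items():
    print(f"N={n}: {count:,} ODEs")
```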

### Relative Performance of Gradient Pass (Lower is Better)

| Number of Parameters | 3 | 4 | 256 | 1,024 |
| --- | --- | --- | --- | --- |
| DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x |
| torchdiffeq dopri5 | 12,000x | 1,200x | ---- | ---- |

`----` means the gradient calculation diverged. To be clear, torchdiffeq did not successfully compute the gradient on any of the PDE experiments. All returned an error due to `dt` underflow, leading to the experiment being halted.

# Benchmark Details

## Lorenz Equation (3 ODEs)

### Absolute Timings

#### Forward Pass

• DifferentialEquations.jl: 1.742 ms
• SciPy+Numba: 30.8 ms
• SciPy: 50.2 ms
• SciPy `solve_ivp`: 869 ms
• torchscript torchdiffeq (dopri5): 8.60 seconds

#### Gradient Pass

• DifferentialEquations.jl: 4.281 ms
• torchscript torchdiffeq: 51.9 seconds

### Relative Timings (Lower is better)

#### Forward Pass

• DifferentialEquations.jl: 1x
• SciPy+Numba: 18x slower
• SciPy: 29x slower
• torchscript torchdiffeq: 4,900x slower

#### Gradient Pass

• DifferentialEquations.jl: 1x
• torchscript torchdiffeq: 12,000x slower
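The relative figures are the absolute timings divided by the fastest timing, rounded to two significant figures. A short sketch of that arithmetic for the Lorenz numbers above, where the helper `round_sig` is an illustrative name and not part of the benchmark scripts:

```python
from math import floor, log10

def round_sig(x, sig=2):
    """Round a positive number x to `sig` significant figures."""
    return round(x, sig - 1 - floor(log10(abs(x))))

# Absolute Lorenz forward-pass timings from this section, in seconds.
forward_seconds = {
    "DifferentialEquations.jl": 1.742e-3,
    "SciPy+Numba": 30.8e-3,
    "SciPy": 50.2e-3,
    "torchscript torchdiffeq (dopri5)": 8.60,
}
baseline = min(forward_seconds.values())
forward_relative = {
    name: round_sig(seconds / baseline)
    for name, seconds in forward_seconds.items()
}

# Gradient timings: 4.281 ms (DifferentialEquations.jl) vs 51.9 s (torchdiffeq).
gradient_relative = round_sig(51.9 / 4.281e-3)

for name, ratio in forward_relative.items():
    print(f"{name}: {ratio:g}x")
print(f"torchdiffeq gradient: {gradient_relative:g}x")
```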

## Pleiades Equation (28 ODEs)

### Absolute Timings

#### Forward Pass

• DifferentialEquations.jl: 2.118 ms
• DifferentialEquations.jl DP5: 3.407 ms
• SciPy+Numba: 2.6 ms
• SciPy: 13.2 ms
• torchscript torchdiffeq (dopri5): 405 ms

#### Gradient Pass

• DifferentialEquations.jl: 5.501 ms
• torchscript torchdiffeq: 6.33 seconds

### Relative Timings (Lower is better)

#### Forward Pass

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 1.6x slower
• SciPy+Numba: 1.2x slower
• SciPy: 6.2x slower
• torchscript torchdiffeq (dopri5): 190x slower

#### Gradient Pass

• DifferentialEquations.jl: 1x
• torchscript torchdiffeq: 1,200x slower

## Non-stiff Reaction Diffusion Equation (N=16) (768 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 3.300 ms
• DifferentialEquations.jl DP5: 9.135 ms
• SciPy: 2.2 seconds
• SciPy+Numba: Failed to compile (numpy.ndarray)
• torchscript torchdiffeq (dopri5): 2.78 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 2.8x slower
• SciPy: 670x slower
• torchscript torchdiffeq (dopri5): 840x slower

## Non-stiff Reaction Diffusion Equation (N=32) (3072 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 14.397 ms
• DifferentialEquations.jl DP5: 38.608 ms
• SciPy: 6.71 seconds
• torchscript torchdiffeq (dopri5): 3.12 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 2.7x slower
• SciPy: 460x slower
• torchscript torchdiffeq (dopri5): 220x slower

## Non-stiff Reaction Diffusion Equation (N=64) (12,288 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 64.192 ms
• DifferentialEquations.jl DP5: 192.216 ms
• SciPy: 174 seconds
• torchscript torchdiffeq (dopri5): 5.24 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 3.0x slower
• SciPy: 2,700x slower
• torchscript torchdiffeq (dopri5): 82x slower

## Non-stiff Reaction Diffusion Equation (N=128) (49,152 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 299.512 ms
• DifferentialEquations.jl DP5: 907.863 ms
• torchscript torchdiffeq (dopri5): 9.29 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 3.0x slower
• torchscript torchdiffeq (dopri5): 31x slower

## Non-stiff Reaction Diffusion Equation (N=256) (196,608 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 1.586 seconds
• DifferentialEquations.jl DP5: 6.195 seconds
• torchscript torchdiffeq (dopri5): 37.5 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 3.9x slower
• torchscript torchdiffeq (dopri5): 24x slower

## Non-stiff Reaction Diffusion Equation (N=512) (786,432 ODEs)

### Absolute Timings

• DifferentialEquations.jl: 10.3 seconds
• DifferentialEquations.jl DP5: 29.3 seconds
• torchscript torchdiffeq (dopri5): 172.59 seconds

### Relative Timings (Lower is better)

• DifferentialEquations.jl: 1x
• DifferentialEquations.jl DP5: 2.8x slower
• torchscript torchdiffeq (dopri5): 17x slower

## Notes

The torchscript versions are kept as separate scripts to allow the JIT process to occur, and each is called once before timing to exclude JIT compilation time, as the PyTorch documentation suggests. Python results were divided by the number of runs passed to timeit. Note that the SciPy slowdown on the reaction-diffusion problem is due to lsoda switching to its BDF (stiff) method and using the Jacobian: with this diffusion coefficient this is unnecessary and leads to a large slowdown.
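The timing protocol described in these notes (one untimed warm-up call, then dividing the total by the number of timed runs) can be sketched with the standard library alone. The helper name `seconds_per_call` and the stand-in workload `work` are illustrative assumptions, not code from the benchmark scripts:

```python
import timeit

def seconds_per_call(func, number):
    """Average wall-clock seconds per call of `func` over `number` timed runs.

    One untimed warm-up call is made first so that one-off costs such as
    JIT compilation are excluded from the measurement.
    """
    func()  # warm-up, not timed
    total = timeit.Timer(func).timeit(number=number)
    return total / number

calls = []

def work():
    # Stand-in workload; in the benchmarks this would be an ODE solve.
    calls.append(sum(range(1000)))

per_call = seconds_per_call(work, number=5)
print(f"{per_call:.3e} seconds per call over 5 runs")
```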

```julia
using OrdinaryDiffEq, StaticArrays, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote
using Statistics # provides `mean`

function lorenz_static(u,p,t)
  @inbounds begin
    dx = p[1]*(u[2]-u[1])
    dy = u[1]*(p[2]-u[3]) - u[2]
    dz = u[1]*u[2] - p[3]*u[3]
  end
  @SVector [dx,dy,dz]
end

u0 = @SVector [1.0,0.0,0.0]
p = @SVector [10.0,28.0,8/3]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz_static,u0,tspan,p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.018 ms (31 allocations: 59.25 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 1.742 ms (35 allocations: 59.17 KiB)

function lorenz!(du,u,p,t)
  du[1] = p[1]*(u[2]-u[1])
  du[2] = u[1]*(p[2]-u[3]) - u[2]
  du[3] = u[1]*u[2] - p[3]*u[3]
end

ap = Array(p) # `ap` was undefined in the original script; assumed to be the array form of `p`
prob = ODEProblem(lorenz!,Array(u0),tspan,ap)
function fz(p)
  mean(solve(prob,DP5(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,ap) # 4.281 ms (8308 allocations: 1.18 MiB)
```
```python
import numpy as np
from scipy.integrate import odeint
import timeit
import numba

def f(u, t, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)
sol = odeint(f, u0, t, args=(10.0,28.0,8/3))

def time_func():
    odeint(f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.05018504399999983 seconds

numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, args=(10.0,28.0,8/3),rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.03088667100000066 seconds
```

```python
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn
import torch.optim as optim
import timeit

@torch.jit.script
class LorenzODE(torch.nn.Module):
    def __init__(self):
        super(LorenzODE, self).__init__()
        self.sigma = nn.Parameter(torch.as_tensor([10.0]))
        self.rho = nn.Parameter(torch.as_tensor([28.0]))
        self.beta = nn.Parameter(torch.as_tensor([2.66]))

    def forward(self, t, u):
        x, y, z = u[0],u[1],u[2]
        du1 = self.sigma * (y - x)
        du2 = x * (self.rho - z) - y
        du3 = x * y - self.beta * z
        return torch.stack([du1, du2, du3])

u0 = torch.tensor([1.0,0.0,0.0])
t = torch.linspace(0, 100, 1001)
odeint(LorenzODE(), u0, t, rtol = 1e-8, atol=1e-8)

tmp = LorenzODE()
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=2)/2 # 8.595667100000014 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()

time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 51.85800955000013 seconds
```
```julia
using OrdinaryDiffEq, BenchmarkTools, DiffEqSensitivity, ForwardDiff, Zygote
using Statistics # provides `mean`

function f(du,u,p,t)
  a = p[1]
  b = p[2]
  c = p[3]
  d = p[4]
  @inbounds begin
    x = view(u,1:7)   # x
    y = view(u,8:14)  # y
    v = view(u,15:21) # x′
    w = view(u,22:28) # y′
    du[1:7] .= a.*v
    du[8:14].= b.*w
    for i in 15:28
      du[i] = zero(u[1])
    end
    for i=1:7,j=1:7
      if i != j
        r = ((x[i]-x[j])^2 + (y[i] - y[j])^2)^(3/2)
        du[14+i] += c*j*(x[j] - x[i])/r
        du[21+i] += d*j*(y[j] - y[i])/r
      end
    end
  end
end

p = ones(4)
prob = ODEProblem(f,[3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0],(0.0,3.0),p)
@btime solve(prob,Tsit5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 4.060 ms (76 allocations: 20.11 KiB)
@btime solve(prob,DP5(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 3.407 ms (76 allocations: 18.61 KiB)
@btime solve(prob,VCABM(),saveat=0.1,reltol=1e-8,abstol=1e-8) # 2.118 ms (137 allocations: 36.59 KiB)

function fz(p)
  mean(solve(prob,VCABM(),p=p,saveat=0.1,reltol=1e-8,abstol=1e-8,sensealg=ForwardDiffSensitivity()))
end
@btime Zygote.gradient(fz,p) # 5.501 ms (600 allocations: 265.66 KiB)
```
```python
import numpy as np
from scipy.integrate import odeint
import timeit
import numba

def f(u,t):
    x = u[0:6]   # x
    y = u[7:13]  # y
    v = u[14:20] # x′
    w = u[21:27] # y′
    du = np.zeros(28)
    du[0:6] = v
    du[7:13]= w
    for i in range(6):
        for j in range(6):
            if i != j:
                r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                du[13+i] += j*(x[j] - x[i])/r
                du[20+i] += j*(y[j] - y[i])/r
    return du

u0 = np.array([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = np.linspace(0, 3.0, 31)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8)

def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.013212921999997889 seconds

numba_f = numba.jit(f,nopython=True)
def time_func():
    odeint(numba_f, u0, t, rtol = 1e-8, atol=1e-8)

timeit.Timer(time_func).timeit(number=100)/100 # 0.0025926070000059556 seconds
```

```python
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn       # added: `nn` is used below
import torch.optim as optim # added: `optim` is used below
import timeit

@torch.jit.script
class PleiadesODE(torch.nn.Module):
    def __init__(self):
        super(PleiadesODE, self).__init__()
        self.a = nn.Parameter(torch.as_tensor([1.0]))
        self.b = nn.Parameter(torch.as_tensor([1.0]))
        self.c = nn.Parameter(torch.as_tensor([1.0]))
        self.d = nn.Parameter(torch.as_tensor([1.0]))

    def forward(self, t, u):
        x = u[0:6]   # x
        y = u[7:13]  # y
        v = u[14:20] # x′
        w = u[21:27] # y′
        du = torch.zeros(28)
        du[0:6] = self.a * v
        du[7:13]= self.b * w
        for i in range(6):
            for j in range(6):
                if i != j:
                    r = ((x[i]-x[j])**2 + (y[i] - y[j])**2)**(3/2)
                    du[13+i] += self.c * j*(x[j] - x[i])/r
                    du[20+i] += self.d * j*(y[j] - y[i])/r
        return du

u0 = torch.tensor([3.0,3.0,-1.0,-3.0,2.0,-2.0,2.0,3.0,-3.0,2.0,0,0,-4.0,4.0,0,0,0,0,0,1.75,-1.5,0,0,0,-1.25,1,0,0])
t = torch.linspace(0, 3.0, 31)

tmp = PleiadesODE()
odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=10)/10 # 0.40516944999999394 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()

time_grad()
timeit.Timer(time_grad).timeit(number=2)/2 # 6.3303714999992735 seconds
```
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # added: needed by the gradient calls below

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 128
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#

function f!(_du,_u,_α₁,t)
  u = reshape(_u,N,N,3)
  du = reshape(_du,N,N,3)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  α₁ = reshape(_α₁,N,N)
  # Interior points
  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Edges
  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Corners
  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 299.512 ms (3263 allocations: 251.83 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 830.273 ms (83 allocations: 10.51 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 907.863 ms (80 allocations: 9.38 MiB)

function f(p)
  mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)

const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
  mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 128
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn # added: `nn` is used below
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 9.293160799999896 seconds
```
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # added: needed by the gradient calls below

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 16
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#

function f!(_du,_u,_α₁,t)
  u = reshape(_u,N,N,3)
  du = reshape(_du,N,N,3)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  α₁ = reshape(_α₁,N,N)
  # Interior points
  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Edges
  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Corners
  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,5.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 3.300 ms (2725 allocations: 4.95 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 8.742 ms (152 allocations: 731.91 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=0.5); # 9.135 ms (152 allocations: 712.92 KiB)

function f(p)
  mean(solve(prob,ROCK4(),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 1.228 s (304531 allocations: 1.37 GiB)

function fz(p)
  mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=ForwardSensitivity()))
end
@btime Zygote.gradient(fz,vec(α₁)) # 3.515 s (3413745 allocations: 1.10 GiB)

function fz2(p)
  mean(solve(prob,ROCK4(),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=BacksolveAdjoint()))
end
@btime Zygote.gradient(fz2,vec(α₁)) # Diverges!
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 16
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

u0 = np.ndarray.flatten(np.zeros((3,N,N)))

#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))

tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

timeit.Timer(time_func).timeit(number=1) # 2.2156629000000976 seconds
```

```python
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn       # added: `nn` is used below
import torch.optim as optim # added: `optim` is used below
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 2.777176000000054 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)
def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()

time_grad() # AssertionError: underflow in dt 8.289381047241916e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.289381047241916e-16
```
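The commented-out "all zero!" lines in the script above compare the slicing stencil against the matrix product with `My`. A dependency-free, one-dimensional sketch of the same comparison is given here; the helper names `build_my`, `matvec`, and `stencil` are illustrative and do not appear in the benchmark scripts.

```python
import random

def build_my(n):
    """Tridiagonal second-difference matrix with doubled off-diagonal
    entries in the first and last rows, mirroring how My is built above."""
    m = [[0.0] * n for _ in range(n)]
    for i in range(n):
        m[i][i] = -2.0
        if i > 0:
            m[i][i - 1] = 1.0
        if i < n - 1:
            m[i][i + 1] = 1.0
    m[0][1] = 2.0
    m[n - 1][n - 2] = 2.0
    return m

def matvec(m, a):
    """Dense matrix-vector product."""
    return [sum(mij * aj for mij, aj in zip(row, a)) for row in m]

def stencil(a):
    """Same operator written as slices, as in the `top`/`bottom` lines above."""
    n = len(a)
    top = -2 * a[0] + 2 * a[1]
    bottom = 2 * a[n - 2] - 2 * a[n - 1]
    interior = [a[i - 1] - 2 * a[i] + a[i + 1] for i in range(1, n - 1)]
    return [top] + interior + [bottom]

random.seed(0)
n = 16
a = [random.random() for _ in range(n)]
max_diff = max(abs(x - y) for x, y in zip(matvec(build_my(n), a), stencil(a)))
print(f"max |My@a - stencil(a)| = {max_diff:.2e}")
```

The stencil form avoids a dense N×N matrix product on every right-hand-side evaluation, which is presumably why the scripts use it in place of `My@A`.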
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
using DiffEqSensitivity, ForwardDiff, Zygote, Statistics # added: needed by the gradient calls below

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 256
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#

function f!(_du,_u,_α₁,t)
  u = reshape(_u,N,N,3)
  du = reshape(_du,N,N,3)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  α₁ = reshape(_α₁,N,N)
  # Interior points
  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Edges
  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
  # Corners
  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) + α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 1.586 s (3239 allocations: 988.71 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 5.285 s (83 allocations: 42.01 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 6.195 s (80 allocations: 37.51 MiB)

function f(p)
  mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)

const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
  mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 256
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

import torch
from torch import nn  # needed for nn.Parameter below
from torchdiffeq import odeint_adjoint as odeint
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 37.48328430000038 seconds
```
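The commented-out check in the script above (the lines ending in `# all zero!`) compares the slice-based stencil against the matrix products `My@A` and `A@Mx`. A NumPy-only sketch of that check follows. The small grid size `N = 8` and the seeded random matrix `A` are assumptions chosen for illustration and are not part of the benchmark.

```python
import numpy as np

N = 8  # small illustrative grid (assumption; the benchmarks use 32 to 512)

# Second-difference matrices with reflecting boundaries, built as in the script
My = np.diag(np.ones(N - 1), -1) + np.diag(-2.0 * np.ones(N), 0) + np.diag(np.ones(N - 1), 1)
Mx = My.copy()
Mx[1, 0] = 2.0
Mx[N - 2, N - 1] = 2.0
My[0, 1] = 2.0
My[N - 1, N - 2] = 2.0

rng = np.random.default_rng(0)  # seeded so the check is repeatable
A = rng.random((N, N))

# Row direction: slice-based equivalent of My @ A
top = -2 * A[0, :] + 2 * A[1, :]
bottom = 2 * A[N - 2, :] - 2 * A[N - 1, :]
MyA = np.vstack((top, A[0:N - 2, :] - 2 * A[1:N - 1, :] + A[2:N, :], bottom))

# Column direction: slice-based equivalent of A @ Mx
left = (-2 * A[:, 0] + 2 * A[:, 1]).reshape(N, 1)
right = (2 * A[:, N - 2] - 2 * A[:, N - 1]).reshape(N, 1)
AMx = np.hstack((left, A[:, 0:N - 2] - 2 * A[:, 1:N - 1] + A[:, 2:N], right))

# Both differences should be zero up to floating-point round-off
print(np.max(np.abs(MyA - My @ A)))
print(np.max(np.abs(AMx - A @ Mx)))
```

The slice form avoids a dense N×N matrix product per right-hand-side call, which is presumably why the scripts use it instead of `tMy@A` and `A@tMx`.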
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
# Needed for `mean` and the gradient calls at the end of the script
# (the sensitivity package is named SciMLSensitivity in newer releases)
using Statistics, ForwardDiff, Zygote, DiffEqSensitivity

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 32
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#

# Hand-written 5-point stencil with reflecting boundaries
function f!(_du,_u,_α₁,t)
  u = reshape(_u,N,N,3)
  du = reshape(_du,N,N,3)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  α₁ = reshape(_α₁,N,N)

  # Interior points
  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge i = 1
  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge i = N
  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge j = 1
  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge j = N
  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Corners
  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 14.397 ms (3311 allocations: 16.50 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 37.298 ms (83 allocations: 679.31 KiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 38.608 ms (80 allocations: 606.47 KiB)

# Forward-mode gradient of the mean of the solution with respect to α₁
function f(p)
  mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁) # 18.495 s (1124288 allocations: 13.90 GiB)

# Adjoint gradient with a fixed eigenvalue estimate for ROCK4
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
  mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁)) # 27.777 s (774420 allocations: 447.04 MiB)
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 32
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

u0 = np.ndarray.flatten(np.zeros((3,N,N)))

#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))

tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

timeit.Timer(time_func).timeit(number=1) # 6.712808100000984 seconds

import torch
from torch import nn, optim  # needed for nn.Parameter and optim.RMSprop below
from torchdiffeq import odeint_adjoint as odeint
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 3.120565799999895 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)

def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()

time_grad() # AssertionError: underflow in dt 8.230862228687043e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 8.230862228687043e-16
```
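The `underflow in dt` failure above occurs in the gradient pass, where `odeint_adjoint` integrates backwards in time. One way to build intuition for why a reverse-time solve can be fragile on a strongly damped problem is the scalar test equation u' = -λu: solving it backwards from t = T multiplies any error in u(T) by exp(λT). The sketch below only illustrates that amplification. The values of `lam`, `T`, and `eps` are assumptions chosen for illustration and are not quantities taken from the benchmark.

```python
import math

lam = 50.0   # illustrative decay rate (assumption)
T = 1.0      # illustrative time span (assumption)
u0 = 1.0     # true initial condition

# Exact forward solution at t = T
uT = u0 * math.exp(-lam * T)

# Perturb the end state by roughly double-precision round-off,
# then solve exactly backwards in time to t = 0
eps = 1e-16
u0_back = (uT + eps) * math.exp(lam * T)

# Any error at t = T is multiplied by exp(lam * T) on the way back
amplification = math.exp(lam * T)

print(f"u(T)               = {uT:.3e}")
print(f"reconstructed u(0) = {u0_back:.3e} (true value {u0})")
print(f"amplification      = {amplification:.3e}")
```

An adaptive solver trying to control an error that grows this quickly may keep shrinking its step size, which is consistent with the assertion seen above, although this toy calculation does not reproduce the benchmark itself.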
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 512
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
const α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,p,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end

function f!(du,u,p,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]

  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0))
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 10.261 s (3406 allocations: 4.34 GiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 27.946 s (266 allocations: 708.02 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 29.294 s (263 allocations: 690.02 MiB)
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 512
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

import torch
from torch import nn  # needed for nn.Parameter below
from torchdiffeq import odeint_adjoint as odeint
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 172.58722800000032 seconds
```
```julia
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, BenchmarkTools
using LoopVectorization
# Needed for `mean` and the gradient calls at the end of the script
# (the sensitivity package is named SciMLSensitivity in newer releases)
using Statistics, ForwardDiff, Zygote, DiffEqSensitivity

const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const D = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 64
const X = reshape([i for i in 1:N for j in 1:N],N,N)
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
α₁ = 1.0.*(X.>=4*N/5)

const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
const My = copy(Mx)
Mx[2,1] = 2.0
Mx[end-1,end] = 2.0
My[1,2] = 2.0
My[end,end-1] = 2.0

#=
# Define the discretized PDE as an ODE function
const MyA = zeros(N,N)
const AMx = zeros(N,N)
const DA = zeros(N,N)
function f(du,u,α₁,t)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  mul!(MyA,My,A)
  mul!(AMx,A,Mx)
  @. DA = D*(MyA + AMx)
  @. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
  @. dB = α₂ - β₂*B - r₁*A*B + r₂*C
  @. dC = α₃ - β₃*C + r₁*A*B - r₂*C
end
=#

# Hand-written 5-point stencil with reflecting boundaries
function f!(_du,_u,_α₁,t)
  u = reshape(_u,N,N,3)
  du = reshape(_du,N,N,3)
  A = @view u[:,:,1]
  B = @view u[:,:,2]
  C = @view u[:,:,3]
  dA = @view du[:,:,1]
  dB = @view du[:,:,2]
  dC = @view du[:,:,3]
  α₁ = reshape(_α₁,N,N)

  # Interior points
  @inbounds for j in 2:N-1, i in 2:N-1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge i = 1
  @inbounds for j in 2:N-1
    i = 1
    dA[1,j] = D*(2A[i+1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge i = N
  @inbounds for j in 2:N-1
    i = N
    dA[end,j] = D*(2A[i-1,j] + A[i,j+1] + A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge j = 1
  @inbounds for i in 2:N-1
    j = 1
    dA[i,j] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Edge j = N
  @inbounds for i in 2:N-1
    j = N
    dA[i,end] = D*(A[i-1,j] + A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
                α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end

  # Corners
  @inbounds begin
    i = 1; j = 1
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = 1; j = N
    dA[i,j] = D*(2A[i+1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = 1
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j+1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]

    i = N; j = N
    dA[i,j] = D*(2A[i-1,j] + 2A[i,j-1] - 4A[i,j]) +
              α₁[i,j] - β₁*A[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dB[i,j] = α₂ - β₂*B[i,j] - r₁*A[i,j]*B[i,j] + r₂*C[i,j]
    dC[i,j] = α₃ - β₃*C[i,j] + r₁*A[i,j]*B[i,j] - r₂*C[i,j]
  end
end

#=
# Double check:
u = rand(N,N,3)
du = similar(u)
du2 = similar(u)
f( du ,u,nothing,0.0)
f!(du2,u,nothing,0.0)
=#

u0 = zeros(N,N,3)
prob = ODEProblem(f!,u0,(0.0,10.0),α₁)
@btime solve(prob, ROCK4(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 64.192 ms (3287 allocations: 64.24 MiB)
@btime solve(prob, Tsit5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 180.134 ms (83 allocations: 2.63 MiB)
@btime solve(prob, DP5(), reltol = 1e-8, abstol=1e-8, saveat=1.0); # 192.216 ms (80 allocations: 2.35 MiB)

# Forward-mode gradient of the mean of the solution with respect to α₁
function f(p)
  mean(solve(prob,ROCK4(),p=p,saveat=1.0,reltol=1e-8,abstol=1e-8))
end
@btime ForwardDiff.gradient(f,α₁)

# Adjoint gradient with a fixed eigenvalue estimate for ROCK4
const EIGEN_EST = Ref(0.0f0)
EIGEN_EST[] = maximum(abs, eigvals(Matrix(My)))
function fz(p)
  mean(solve(prob,ROCK4(eigen_est = (integ)->integ.eigen_est = EIGEN_EST[]),u0=vec(prob.u0),p=p,saveat=0.5,reltol=1e-8,abstol=1e-8,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
@btime Zygote.gradient(fz,vec(α₁))
```
```python
import numpy as np
from scipy.integrate import odeint

a2 = 1.0
a3 = 1.0
b1 = 1.0
b2 = 1.0
b3 = 1.0
r1 = 1.0
r2 = 1.0
_DD = 100.0
g1 = 0.1
g2 = 0.1
g3 = 0.1
N = 64
X = np.reshape([j+1 for i in range(N) for j in range(N)],(N,N))
Y = np.reshape([i+1 for i in range(N) for j in range(N)],(N,N))
a1 = 1.0*(X>=4*N/5)

Mx = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
My = np.diag([1.0 for i in range(N-1)],-1) + np.diag([-2.0 for i in range(N)],0) + np.diag([1.0 for i in range(N-1)],1)
Mx[1,0] = 2.0
Mx[N-2,N-1] = 2.0
My[0,1] = 2.0
My[N-1,N-2] = 2.0

u0 = np.ndarray.flatten(np.zeros((3,N,N)))

#A = np.random.rand(N,N)
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - My@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

# Define the discretized PDE as an ODE function
def f(_u,t):
    u = np.reshape(_u,(3,N,N))
    A = u[0,:,:]
    B = u[1,:,:]
    C = u[2,:,:]
    # MyA = My@A
    top = -2*A[0,:] + 2*A[1,:]
    bottom = 2*A[N-2,:] - 2*A[N-1,:]
    MyA = np.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
    # AMx = A@Mx
    left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
    right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
    AMx = np.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
    DA = _DD*(MyA + AMx)
    dA = DA + a1 - b1*A - r1*A*B + r2*C
    dB = a2 - b2*B - r1*A*B + r2*C
    dC = a3 - b3*C + r1*A*B - r2*C
    return np.ndarray.flatten(np.concatenate([dA,dB,dC]))

tspan = (0., 10.)
t = np.linspace(0, 10, 101)
sol = odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

import timeit
def time_func():
    odeint(f, u0, t, rtol = 1e-8, atol=1e-8,mxstep=1000)

timeit.Timer(time_func).timeit(number=1) # 173.78868460000012 seconds

import torch
from torch import nn, optim  # needed for nn.Parameter and optim.RMSprop below
from torchdiffeq import odeint_adjoint as odeint
import timeit

tMx = torch.from_numpy(Mx)
tMy = torch.from_numpy(My)
tX = torch.from_numpy(X)
tY = torch.from_numpy(Y)
ta1 = torch.from_numpy(a1)

#A = torch.from_numpy(np.random.rand(N,N))
#top = -2*A[0,:] + 2*A[1,:]
#bottom = 2*A[N-2,:] - 2*A[N-1,:]
#torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom)) - tMy@A # all zero!
#left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
#right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
#torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right)) - A@Mx # all zero!

@torch.jit.script
class ReactionDiffusionODE(torch.nn.Module):
    def __init__(self):
        super(ReactionDiffusionODE, self).__init__()
        self.ta1 = nn.Parameter(ta1)

    def forward(self, t, _u):
        u = torch.reshape(_u,(3,N,N))
        A = u[0,:,:]
        B = u[1,:,:]
        C = u[2,:,:]
        #MyA = tMy@A
        top = -2*A[0,:] + 2*A[1,:]
        bottom = 2*A[N-2,:] - 2*A[N-1,:]
        MyA = torch.vstack((top,A[0:N-2,:] - 2*A[1:N-1,:] + A[2:N,:],bottom))
        #AMx = A@tMx
        left = (-2*A[:,0] + 2*A[:,1]).reshape(N,1)
        right = (2*A[:,N-2] - 2*A[:,N-1]).reshape(N,1)
        AMx = torch.hstack((left,A[:,0:N-2] - 2*A[:,1:N-1] + A[:,2:N],right))
        DA = _DD*(MyA + AMx)
        dA = DA + ta1 - b1*A - r1*A*B + r2*C
        dB = a2 - b2*B - r1*A*B + r2*C
        dC = a3 - b3*C + r1*A*B - r2*C
        return torch.flatten(torch.cat([dA,dB,dC]))

tmp = ReactionDiffusionODE()
u0 = torch.flatten(torch.zeros((3,N,N),dtype=torch.float64))
t = torch.linspace(0, 10, 101)
sol = odeint(tmp, u0, t)

def time_func():
    with torch.no_grad():
        odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)

time_func()
timeit.Timer(time_func).timeit(number=1) # 5.241670299998077 seconds

optimizer = optim.RMSprop(tmp.parameters(), lr=1e-3)

def time_grad():
    optimizer.zero_grad()
    out = odeint(tmp, u0, t, rtol = 1e-8, atol=1e-8)
    loss = torch.mean(out)
    loss.backward()

time_grad() # AssertionError: underflow in dt 7.349905693271135e-16
timeit.Timer(time_grad).timeit(number=2)/2 # AssertionError: underflow in dt 7.349905693271135e-16
```
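As a rough cross-check, the relative forward-pass figures can be recomputed from the timings recorded in the script comments for the four reaction-diffusion grids (N = 32, 64, 256, 512). The sketch below is not part of the original benchmark. The `timings` dictionary and the `relative_timings` helper are illustrative names, and the ODE count assumes three species on an N×N grid.

```python
# Forward-pass wall-clock timings in seconds, copied from the script comments
timings = {
    32:  {"ROCK4": 14.397e-3, "Tsit5": 37.298e-3, "DP5": 38.608e-3,
          "torchdiffeq": 3.120565799999895},
    64:  {"ROCK4": 64.192e-3, "Tsit5": 180.134e-3, "DP5": 192.216e-3,
          "torchdiffeq": 5.241670299998077},
    256: {"ROCK4": 1.586, "Tsit5": 5.285, "DP5": 6.195,
          "torchdiffeq": 37.48328430000038},
    512: {"ROCK4": 10.261, "Tsit5": 27.946, "DP5": 29.294,
          "torchdiffeq": 172.58722800000032},
}

def relative_timings(per_solver):
    """Return each solver's time divided by the fastest time in the group."""
    fastest = min(per_solver.values())
    return {name: seconds / fastest for name, seconds in per_solver.items()}

for N, per_solver in sorted(timings.items()):
    n_odes = 3 * N * N  # three species (A, B, C) on an N x N grid
    rel = relative_timings(per_solver)
    row = ", ".join(f"{name} {ratio:.1f}x" for name, ratio in rel.items())
    print(f"{n_odes:>7,} ODEs: {row}")
```

The ratios printed this way can be compared against the summary table. Small differences in the last digit are possible, since the table may have been produced from a different run or rounding convention.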

### toollu commented Aug 10, 2020 • edited

I'm relatively new to both Python and Julia, so bear with me.
First of all thanks for your amazing work on the diffeqflux library, it looks super promising for my purpose.
However, when I run lorenz.py as it is I get the following trace:

```
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/user/SRC/scratch.py", line 39, in <module>
    class LorenzODE(torch.nn.Module):
  File "/opt/anaconda3/envs/DSP/lib/python3.7/site-packages/torch/jit/__init__.py", line 1525, in script
RuntimeError: Type '<class '__main__.LorenzODE'>' cannot be compiled since it inherits from nn.Module, pass an instance instead.
```

Without the `@torch.jit.script` decorator it runs fine, however. Is it the code or is it my environment?

### ChrisRackauckas commented Aug 10, 2020

The torch jit has some limitations on how you run the JIT'd code IIRC. I think the easiest one to run was https://gist.github.com/ChrisRackauckas/cc6ac746e2dfd285c28e0584a2bfd320/revisions#diff-0cb4ce6da1b7bafa3e58087c7e9759a6 , so I might just revive that version.
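The error message in the traceback suggests passing an instance to `torch.jit.script` rather than decorating the class. A minimal sketch of that approach follows, assuming a simplified `LorenzODE` module. The class body here is illustrative and is not the gist's `lorenz.py`.

```python
import torch

class LorenzODE(torch.nn.Module):
    """Illustrative Lorenz right-hand side (an assumption, not the gist's code)."""

    def __init__(self):
        super().__init__()

    def forward(self, t, u):
        x, y, z = u[0], u[1], u[2]
        du1 = 10.0 * (y - x)
        du2 = x * (28.0 - z) - y
        du3 = x * y - (8.0 / 3.0) * z
        return torch.stack([du1, du2, du3])

# Script an instance instead of decorating the class with @torch.jit.script
scripted = torch.jit.script(LorenzODE())

# Evaluate the scripted right-hand side once at t = 0
out = scripted(torch.tensor(0.0), torch.tensor([1.0, 0.0, 0.0]))
print(out)
```

In principle the scripted module can then be passed to the solver in place of the unscripted one, though whether that changes the timings has not been checked here.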

### chacowingnut commented Jul 8, 2021

Howdy, I noticed that you're calling the old `odeint`, which is deprecated. I tried reproducing your Lorenz Python results using scipy's newer `solve_ivp`, and achieved a dramatic speed boost. On my machine the forward pass timing improved from 45ish ms to ~1.1 ms. Using DOP853 via `solve_ivp`'s method kwarg I get timings just a hair above half a ms.

```python
import numpy as np
from scipy.integrate import solve_ivp
import timeit

def f(t, u, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)
sol = solve_ivp(f, t, u0, args=(10.0,28.0,8/3))

def time_func():
    solve_ivp(f, t, u0, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8)

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.0011336688511073589 seconds
```

### ChrisRackauckas commented Jul 10, 2021

The timings use `odeint` because `odeint` is known to be a lot faster than `solve_ivp`. So then why is `solve_ivp` faster in your example? Well, because there's an error in the code 😅. If you check the code you wrote:

```python
import numpy as np
from scipy.integrate import solve_ivp
import timeit

def f(t, u, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
t = np.linspace(0, 100, 1001)

sol = solve_ivp(f, t, u0, args=(10.0,28.0,8/3), rtol=1e-8, atol=1e-8, method="RK45")
len(sol.t) # 18
```

You see that it doesn't act quite the way that you would expect. You asked for the values at `t` but it gave you 18 values instead of 1001. You need to add the keyword argument `t_eval` in order to control the output in `solve_ivp`:

```python
sol = solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="RK45")
len(sol.t) # 1001
```

and what's the timing with the corrected arguments?

```python
def time_func():
    solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="RK45")

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.8685612489999994 seconds
```

`DOP853` then is similar:

```python
def time_func():
    solve_ivp(f, tspan, u0, args=(10.0,28.0,8/3), t_eval = t, rtol=1e-8, atol=1e-8, method="DOP853")

print(timeit.Timer(time_func).timeit(number=100)/100)  # 0.5002750910000009 seconds
```

This particular test is stress testing more of the implementation than the algorithm itself.
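To get a feel for how much of that runtime is interpreter overhead, one can time the right-hand-side function on its own. This is a rough sketch using only the standard library. `n_calls` is an arbitrary choice, and the figure printed will vary by machine and Python version.

```python
import timeit

def f(t, u, sigma, rho, beta):
    # Lorenz right-hand side, as in the snippets above
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0, 0.0, 0.0]
args = (10.0, 28.0, 8 / 3)

n_calls = 100_000  # arbitrary repetition count (assumption)
total = timeit.timeit(lambda: f(0.0, u0, *args), number=n_calls)
per_call = total / n_calls

print(f"{per_call * 1e9:.0f} ns per right-hand-side call")
```

The function performs only a handful of floating-point operations, so most of the time per call is typically Python call and list-handling overhead, and an adaptive solver makes many such calls per solve.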

### chacowingnut commented Jul 10, 2021

Haha I see. Many thanks for the response and for your contributions in this area!

### ChrisRackauckas commented Jul 10, 2021

No problem, and thanks for your contribution. Noting the `solve_ivp` time is definitely important.