@ChrisRackauckas, last active September 28, 2022
torchdiffeq vs Julia DiffEqFlux Neural ODE Training Benchmark

The spiral neural ODE was used as the training benchmark for both torchdiffeq (Python) and DiffEqFlux (Julia). Both used the same architecture and 500 steps of ADAM, and both achieved similar objective values at the end (the ground-truth dynamics are given after the results). Results:

  • DiffEqFlux defaults: 7.4 seconds
  • DiffEqFlux optimized: 2.7 seconds
  • torchdiffeq: 288.965871299999 seconds
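
For reference, the ground-truth spiral dynamics generating the training data (implemented as Lambda and trueODEfunc below) are

    \frac{du}{dt} = A^{\top} u^{3}, \qquad A = \begin{pmatrix} -0.1 & 2.0 \\ -2.0 & -0.1 \end{pmatrix}

with the cube applied elementwise, solved from u_0 = (2, 0) over t \in [0, 1.5] and saved at 30 points.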

Training-time ratio of torchdiffeq to DiffEqFlux (how many times longer the Python training took; higher means a larger Julia speedup):

  • vs. unoptimized defaults: 39x
  • vs. optimized sensitivity: 107x
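
These ratios follow directly from the wall-clock timings reported below; as a quick check in Julia:

    288.965871299999 / 7.359051  # ≈ 39.3x (vs. defaults)
    288.965871299999 / 2.687161  # ≈ 107.5x (vs. optimized sensitivity)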

Final loss values:

  • DiffEqFlux defaults: 4.895287e-02
  • DiffEqFlux optimized: 2.761669e-02
  • torchdiffeq: 0.0596
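
The torchdiffeq (Python) benchmark script: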
import numpy as np
import timeit
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint

# Ground-truth "spiral" dynamics: dy/dt = (y^3) A, with y a row vector
true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 1.5, 30)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])

class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y**3, true_A)

# Generate the training data by solving the true ODE
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

# Neural ODE right-hand side: cube the state, then a 2 -> 50 -> 2 tanh network
class ODEFunc(nn.Module):
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y**3)

# First run (untimed): 500 steps of ADAM
func = ODEFunc()
optimizer = optim.Adam(func.parameters(), lr=0.05)

def time_func():
    for itr in range(1, 501):
        optimizer.zero_grad()
        pred_y = odeint(func, true_y0, t)
        loss = torch.sum((pred_y - true_y)**2)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        pred_y = odeint(func, true_y0, t)
        print(torch.sum((pred_y - true_y)**2))

time_func()

# Timed run: fresh model and optimizer, same 500 steps of ADAM
func = ODEFunc()
optimizer = optim.Adam(func.parameters(), lr=0.05)

def time_func():
    for itr in range(1, 501):
        optimizer.zero_grad()
        pred_y = odeint(func, true_y0, t)
        loss = torch.sum((pred_y - true_y)**2)
        loss.backward()
        optimizer.step()

timeit.Timer(time_func).timeit(number=1) # 288.965871299999 seconds

with torch.no_grad():
    pred_y = odeint(func, true_y0, t)
    print(torch.sum((pred_y - true_y)**2))
# tensor(0.0596)
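
The equivalent DiffEqFlux (Julia) benchmark using the default sensitivity analysis: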
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Ground-truth "spiral" dynamics: du/dt = A' * (u.^3)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

# Generate the training data by solving the true ODE
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Same architecture as the PyTorch model: cube the state, then 2 -> 50 -> 2 with tanh
dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 50, tanh),
                  FastDense(50, 2))
neural_ode_f(u, p, t) = dudt2(u, p)
pinit = initial_params(dudt2)
prob = ODEProblem(neural_ode_f, u0, tspan, pinit)

function predict_neuralode(p)
    tmp_prob = remake(prob, p = p)
    Array(solve(tmp_prob, Tsit5(), saveat = tsteps))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    #display(l)
    # plot current prediction against data
    #plt = scatter(tsteps, ode_data[1,:], label = "data")
    #scatter!(plt, tsteps, pred[1,:], label = "prediction")
    #if doplot
    #    display(plot(plt))
    #end
    return false
end

@time result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, pinit,
                                                ADAM(0.05), cb = callback,
                                                maxiters = 500)
#=
7.359051 seconds (53.22 M allocations: 5.284 GiB, 15.72% gc time)

* Status: success

* Candidate solution
   Final objective value:     4.895287e-02

* Found with
   Algorithm:     ADAM

* Convergence measures
   |x - x'|               = NaN ≰ 0.0e+00
   |x - x'|/|x'|          = NaN ≰ 0.0e+00
   |f(x) - f(x')|         = NaN ≰ 0.0e+00
   |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
   |g(x)|                 = NaN ≰ 0.0e+00

* Work counters
   Seconds run:   7  (vs limit Inf)
   Iterations:    500
   f(x) calls:    500
   ∇f(x) calls:   500
=#
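
The optimized run below is identical except that the solve call in predict_neuralode specifies the sensitivity algorithm explicitly, sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)), which computes the adjoint with compiled reverse-mode vector-Jacobian products instead of the default choice: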
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, DiffEqSensitivity

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Ground-truth "spiral" dynamics: du/dt = A' * (u.^3)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

# Generate the training data by solving the true ODE
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 50, tanh),
                  FastDense(50, 2))
neural_ode_f(u, p, t) = dudt2(u, p)
pinit = initial_params(dudt2)
prob = ODEProblem(neural_ode_f, u0, tspan, pinit)

function predict_neuralode(p)
    tmp_prob = remake(prob, p = p)
    # The one change from the defaults run: backsolve adjoint with
    # compiled ReverseDiff vector-Jacobian products
    Array(solve(tmp_prob, Tsit5(), saveat = tsteps,
                sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true))))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    #display(l)
    # plot current prediction against data
    #plt = scatter(tsteps, ode_data[1,:], label = "data")
    #scatter!(plt, tsteps, pred[1,:], label = "prediction")
    #if doplot
    #    display(plot(plt))
    #end
    return false
end

@time result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, pinit,
                                                ADAM(0.05), cb = callback,
                                                maxiters = 500)
#=
2.687161 seconds (17.79 M allocations: 1002.418 MiB, 7.41% gc time)

* Status: success

* Candidate solution
   Final objective value:     2.761669e-02

* Found with
   Algorithm:     ADAM

* Convergence measures
   |x - x'|               = NaN ≰ 0.0e+00
   |x - x'|/|x'|          = NaN ≰ 0.0e+00
   |f(x) - f(x')|         = NaN ≰ 0.0e+00
   |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
   |g(x)|                 = NaN ≰ 0.0e+00

* Work counters
   Seconds run:   3  (vs limit Inf)
   Iterations:    500
   f(x) calls:    500
   ∇f(x) calls:   500
=#