torchdiffeq vs Julia DiffEqFlux Neural ODE Training Benchmark

The spiral neural ODE served as the training benchmark for both torchdiffeq (Python) and DiffEqFlux (Julia). Both used the same architecture and 500 steps of ADAM, and both reached similar objective values at the end. Results:

• DiffEqFlux defaults: 7.4 seconds
• DiffEqFlux optimized: 2.7 seconds
• torchdiffeq: 288.965871299999 seconds

Relative time to train for Python vs Julia's DiffEqFlux (lower is better)

• Unoptimized defaults: 39x
• Optimized sensitivity: 107x
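
The relative times above can be checked by dividing the torchdiffeq training time by each DiffEqFlux training time. The sketch below uses the rounded timings from the results list; the names `timings`, `relative_time`, and `speedups` are illustrative and do not come from the benchmark scripts.

```python
# Training times in seconds, copied from the results list above.
timings = {
    "DiffEqFlux defaults": 7.4,
    "DiffEqFlux optimized": 2.7,
    "torchdiffeq": 288.965871299999,
}


def relative_time(baseline: str, other: str = "torchdiffeq") -> float:
    """Return how many times longer `other` took than `baseline`."""
    return timings[other] / timings[baseline]


# Ratio of the torchdiffeq time to each DiffEqFlux configuration.
speedups = {
    name: relative_time(name)
    for name in timings
    if name != "torchdiffeq"
}

for name, ratio in speedups.items():
    print(f"{name}: {ratio:.0f}x")
```

Rounded to the nearest integer, this gives 39x for the defaults and 107x for the optimized sensitivity configuration, matching the figures reported above.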

Final loss values:

• DiffEqFlux defaults: 4.895287e-02
• DiffEqFlux optimized: 2.761669e-02
• torchdiffeq: 0.0596

torchdiffeq (Python):

```python
import numpy as np
import timeit
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint

# Ground-truth spiral problem: initial condition, save times, and dynamics matrix
true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 1.5, 30)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])

class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y**3, true_A)

# Generate the training data from the true dynamics
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

class ODEFunc(nn.Module):
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y**3)

# First, untimed run
func = ODEFunc()
optimizer = optim.Adam(func.parameters(), lr=0.05)

def time_func():
    for itr in range(1, 501):
        optimizer.zero_grad()
        pred_y = odeint(func, true_y0, t)
        loss = torch.sum((pred_y - true_y)**2)
        loss.backward()
        optimizer.step()

with torch.no_grad():
    pred_y = odeint(func, true_y0, t)
    print(torch.sum((pred_y - true_y)**2))

time_func()

# Second run on a freshly initialized network; this is the timed run
func = ODEFunc()
optimizer = optim.Adam(func.parameters(), lr=0.05)

def time_func():
    for itr in range(1, 501):
        optimizer.zero_grad()
        pred_y = odeint(func, true_y0, t)
        loss = torch.sum((pred_y - true_y)**2)
        loss.backward()
        optimizer.step()

timeit.Timer(time_func).timeit(number=1)  # 288.965871299999 seconds

# Final loss after training
with torch.no_grad():
    pred_y = odeint(func, true_y0, t)
    print(torch.sum((pred_y - true_y)**2))  # tensor(0.0596)
```

DiffEqFlux defaults (Julia):

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 50, tanh),
                  FastDense(50, 2))
neural_ode_f(u,p,t) = dudt2(u,p)
pinit = initial_params(dudt2)
prob = ODEProblem(neural_ode_f, u0, tspan, pinit)

function predict_neuralode(p)
    tmp_prob = remake(prob,p=p)
    Array(solve(tmp_prob,Tsit5(),saveat=tsteps))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    #display(l)
    # plot current prediction against data
    #plt = scatter(tsteps, ode_data[1,:], label = "data")
    #scatter!(plt, tsteps, pred[1,:], label = "prediction")
    #if doplot
    #  display(plot(plt))
    #end
    return false
end

@time result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, pinit,
                                                ADAM(0.05), cb = callback,
                                                maxiters = 500)

#=
  7.359051 seconds (53.22 M allocations: 5.284 GiB, 15.72% gc time)
 * Status: success

 * Candidate solution
    Final objective value:     4.895287e-02

 * Found with
    Algorithm:     ADAM

 * Convergence measures
    |x - x'|               = NaN ≰ 0.0e+00
    |x - x'|/|x'|          = NaN ≰ 0.0e+00
    |f(x) - f(x')|         = NaN ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
    |g(x)|                 = NaN ≰ 0.0e+00

 * Work counters
    Seconds run:   7  (vs limit Inf)
    Iterations:    500
    f(x) calls:    500
    ∇f(x) calls:   500
=#
```

DiffEqFlux optimized (Julia):

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, DiffEqSensitivity

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 50, tanh),
                  FastDense(50, 2))
neural_ode_f(u,p,t) = dudt2(u,p)
pinit = initial_params(dudt2)
prob = ODEProblem(neural_ode_f, u0, tspan, pinit)

# Only difference from the default run: an explicit sensitivity algorithm
function predict_neuralode(p)
    tmp_prob = remake(prob,p=p)
    Array(solve(tmp_prob,Tsit5(),saveat=tsteps,
                sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    #display(l)
    # plot current prediction against data
    #plt = scatter(tsteps, ode_data[1,:], label = "data")
    #scatter!(plt, tsteps, pred[1,:], label = "prediction")
    #if doplot
    #  display(plot(plt))
    #end
    return false
end

@time result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, pinit,
                                                ADAM(0.05), cb = callback,
                                                maxiters = 500)

#=
  2.687161 seconds (17.79 M allocations: 1002.418 MiB, 7.41% gc time)
 * Status: success

 * Candidate solution
    Final objective value:     2.761669e-02

 * Found with
    Algorithm:     ADAM

 * Convergence measures
    |x - x'|               = NaN ≰ 0.0e+00
    |x - x'|/|x'|          = NaN ≰ 0.0e+00
    |f(x) - f(x')|         = NaN ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
    |g(x)|                 = NaN ≰ 0.0e+00

 * Work counters
    Seconds run:   3  (vs limit Inf)
    Iterations:    500
    f(x) calls:    500
    ∇f(x) calls:   500
=#
```