
ChrisRackauckas/diffeq_vs_torchsde.md

Last active September 28, 2022 20:49
torchsde vs DifferentialEquations.jl / DiffEqFlux.jl (Julia) benchmarks


This example is a 4-dimensional geometric Brownian motion. The code for the torchsde version is taken directly from the torchsde README, so the comparison is fair and uses the library authors' own code. The only change to that example is the addition of a `dt` choice so that the simulation method and time step match between the two programs.
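For reference, the test problem is geometric Brownian motion with diagonal noise, $dX_t = \mu X_t\,dt + \sigma X_t\,dW_t$ in each component, whose exact solution is $X_t = X_0 \exp((\mu - \sigma^2/2)t + \sigma W_t)$. The sketch below only illustrates the problem and is not either benchmarked program: it swaps in the simple Euler-Maruyama scheme (the Julia code uses SRIW1, a higher-order stochastic Runge-Kutta method), uses only the Python standard library, and assumes the Julia parameters ($\mu = 0.5$, $\sigma = 0.1$, $X_0 = 0.1$, 4 components) with 19 fixed steps on $[0, 1]$, matching `dt = saveat[2]`. The helper names `euler_maruyama_gbm` and `exact_gbm` are hypothetical.

```python
import math
import random


def euler_maruyama_gbm(mu, sigma, x0, dim, n_steps, t_end=1.0, rng=None):
    """Fixed-step Euler-Maruyama for dX = mu*X dt + sigma*X dW (diagonal noise).

    Hypothetical helper. Returns the final state and the final Brownian
    values so the result can be compared with the exact solution on the
    same Brownian path.
    """
    rng = rng or random.Random()
    dt = t_end / n_steps
    x = [x0] * dim
    w = [0.0] * dim
    for _ in range(n_steps):
        for i in range(dim):
            dw = rng.gauss(0.0, math.sqrt(dt))  # independent increment per component
            x[i] += mu * x[i] * dt + sigma * x[i] * dw
            w[i] += dw
    return x, w


def exact_gbm(mu, sigma, x0, w_t, t):
    """Exact GBM solution X_t = X_0 * exp((mu - sigma^2/2) t + sigma W_t)."""
    return x0 * math.exp((mu - 0.5 * sigma**2) * t + sigma * w_t)


# Assumed parameters, taken from the Julia version of the benchmark.
mu, sigma, x0 = 0.5, 0.1, 0.1
x_em, w = euler_maruyama_gbm(mu, sigma, x0, dim=4, n_steps=19, rng=random.Random(0))
x_exact = [exact_gbm(mu, sigma, x0, wi, 1.0) for wi in w]
print([a - b for a, b in zip(x_em, x_exact)])  # per-component discretization error
```

Driving the exact formula with the same Brownian values gives a direct check of the discretization error at this step size.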

The SDE is solved 100 times with fixed time steps. The total times for all 100 solves are:

• torchsde: 1.87 seconds
• DifferentialEquations.jl: 0.00115 seconds

This is roughly a 1,600x performance difference in favor of Julia on the Python library's own README example. Further testing against torchsde could not be completed because of these performance issues.
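As a quick check of the headline ratio, the arithmetic below uses only the two totals reported above (each is the time for 100 solves); nothing else is assumed.

```python
# Totals reported above for 100 solves of the same SDE.
torchsde_total_s = 1.87
julia_total_s = 0.00115
n_solves = 100

speedup = torchsde_total_s / julia_total_s
per_solve_torchsde_ms = torchsde_total_s / n_solves * 1e3
per_solve_julia_us = julia_total_s / n_solves * 1e6

print(f"speedup: {speedup:.0f}x")
print(f"torchsde: {per_solve_torchsde_ms:.1f} ms per solve")
print(f"DifferentialEquations.jl: {per_solve_julia_us:.1f} us per solve")
```

This gives about 1,626x, consistent with the rounded 1,600x figure, or roughly 18.7 ms versus 11.5 µs per solve.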

We note that the performance difference in the context of neural SDEs is likely smaller, because a larger share of the time is spent in matrix multiplication kernels. However, given that full SDE training examples like those demonstrated here generally take about a minute, we still expect a major performance difference, but we currently do not have the compute time to run a full demonstration.

```julia
using StochasticDiffEq, StaticArrays

const μ = 0.5ones(4)
const σ = 0.1ones(4)
f(du,u,p,t) = du .= μ .* u
g(du,u,p,t) = du .= σ .* u
u0 = 0.1ones(4)
tspan = (0.0,1.0)
saveat = range(0.0,1.0,length=20)
prob = SDEProblem(f,g,u0,tspan)

@time for i in 1:100
    sol = solve(prob,SRIW1(),adaptive=false,dt=saveat[2])
end
#0.001154 seconds (16.50 k allocations: 1.109 MiB)
#0.001162 seconds (16.50 k allocations: 1.109 MiB)
#0.001148 seconds (16.50 k allocations: 1.109 MiB)
```
```python
import timeit

import torch
from torchsde import sdeint


class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"

        self.mu = mu
        self.sigma = sigma

    # Drift.
    def f(self, t, y):
        return self.mu * y

    # Diffusion (diagonal noise).
    def g(self, t, y):
        return self.sigma * y


batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)
y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)
ys = sdeint(geometric_bm, y0, ts)  # README example call, run once before timing.

def time_func():
    ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})

print(timeit.Timer(time_func).timeit(number=100))
# 1.8681289999999535 seconds
# 1.8695188000001508 seconds
```

lxuechen commented Aug 4, 2020

Speaking of jitting, here's an example of how one would jit the SDE:

```python
import torch
from torchsde import sdeint
import timeit

class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"

        self.mu = mu
        self.sigma = sigma

    @torch.jit.export
    def f(self, t, y):
        return self.mu * y

    @torch.jit.export
    def g(self, t, y):
        return self.sigma * y

batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)

# Works for torch==1.6.0.
geometric_bm = torch.jit.script(geometric_bm)

y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)

def time_func():
    ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})

print(timeit.Timer(time_func).timeit(number=100))
```

For now, scripting actually makes it slightly slower, though in cases where `f` and `g` contain control flow or indexing, there might be an improvement.

ChrisRackauckas commented Aug 4, 2020

Thanks for the confirmation of the JIT script. Yeah, I was trying it, but I think the Torch JIT really works well in a different regime. Good to know I had that right.

BTW, what exactly is the trapezoidal approximation there? Is that for filling in gaps of the Brownian tree?

lxuechen commented Aug 6, 2020

The trapezoidal approximation estimates $\int_{t_k}^{t_{k+1}} W_s \, ds$ based solely on Brownian motion increments. It will soon be replaced with something faster that tracks two sequences of random variables, like the RSwM algorithms you have.

For fixed-step solvers, we don't necessarily need to use an approximation, since the integrated term can be sampled exactly. For adaptive solvers, since there might be backtracking, the solution will be wrong if we don't store these integrated terms and do conditioning. Tracking the extra sequence required some extra work, so for the initial release I used the approximation. The option is off by default to make sure adaptive solvers' solutions are correct. Undoubtedly, using the approximation could result in a drop in order if the inner loop is not accurate enough. That said, this issue will be resolved when I merge in the faster data structures that record the integrated terms with proper conditioning.
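A small Monte Carlo sketch (standard library only, not torchsde code) makes the difference between the approximation and exact sampling concrete. Over one step of size $h$, let $\Delta W = W_{t_{k+1}} - W_{t_k}$ and $\Delta Z = \int_{t_k}^{t_{k+1}} (W_s - W_{t_k})\,ds$, so that $\int_{t_k}^{t_{k+1}} W_s\,ds = h W_{t_k} + \Delta Z$. For a fixed step the pair can be sampled exactly as $\Delta W = \sqrt{h}\,U_1$ and $\Delta Z = \tfrac{1}{2} h^{3/2} (U_1 + U_2/\sqrt{3})$ with independent standard normals $U_1, U_2$. Assuming the trapezoidal approximation means the rule $\tfrac{h}{2}(W_{t_k} + W_{t_{k+1}})$, it replaces $\Delta Z$ by $h\,\Delta W/2$, which is the conditional mean of $\Delta Z$ given $\Delta W$, and so it drops a term of variance $h^3/12$. The function names are hypothetical.

```python
import math
import random


def sample_step(h, rng):
    """Sample (dW, dZ) exactly for one step of size h.

    dW = W(t+h) - W(t) and dZ = integral_t^{t+h} (W(s) - W(t)) ds are jointly
    Gaussian with Var(dW) = h, Var(dZ) = h^3/3 and Cov(dW, dZ) = h^2/2.
    """
    u1, u2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    dw = math.sqrt(h) * u1
    dz = 0.5 * h**1.5 * (u1 + u2 / math.sqrt(3.0))
    return dw, dz


def trapezoidal_dz(h, dw):
    """Assumed trapezoidal estimate of dZ, using only the increment dW."""
    return 0.5 * h * dw


def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


h, n = 0.1, 100_000
rng = random.Random(42)
exact, trap = [], []
for _ in range(n):
    dw, dz = sample_step(h, rng)
    exact.append(dz)
    trap.append(trapezoidal_dz(h, dw))

err = [e - t for e, t in zip(exact, trap)]
print(variance(exact) / h**3)  # theory: 1/3
print(variance(trap) / h**3)   # theory: 1/4
print(variance(err) / h**3)    # theory: 1/12, the part the approximation drops
```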

The implementation of `BrownianPath` is actually not based on the Brownian tree algorithm we proposed. Rather, it simply stores all the queries for speed. The C++ implementation is based on `std::map` and has O(log n) insertion and Brownian bridge queries. I had some plans for using a splay tree for this, since queries by the solver are usually quite close to each other, but that's not near the top of my priority list for now.
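A minimal Python sketch of the "store every query" idea (hypothetical, not the torchsde `BrownianPath` code): queried $(t, W_t)$ pairs are kept sorted by time; a query between two stored times $s < t < u$ is drawn from the Brownian bridge $\mathcal{N}\big(W_s + \tfrac{t-s}{u-s}(W_u - W_s),\ \tfrac{(t-s)(u-t)}{u-s}\big)$, and a query past the last stored time extends the path with a fresh Gaussian increment. Here `bisect` gives the O(log n) search, but inserting into a Python list is O(n); a balanced tree such as `std::map` is what makes insertion O(log n) as well.

```python
import bisect
import math
import random


class SimpleBrownianPath:
    """Hypothetical scalar Brownian path that remembers every query.

    Queries between two stored times are filled in with a Brownian bridge,
    so repeated or backtracking queries stay consistent with earlier ones.
    """

    def __init__(self, t0=0.0, w0=0.0, seed=None):
        self._ts = [t0]
        self._ws = [w0]
        self._rng = random.Random(seed)

    def __call__(self, t):
        i = bisect.bisect_left(self._ts, t)  # O(log n) search
        if i < len(self._ts) and self._ts[i] == t:
            return self._ws[i]  # exact cache hit
        if i == 0:
            raise ValueError("query precedes the start of the path")
        if i == len(self._ts):
            # Past the last stored time: extend with an independent increment.
            s, ws = self._ts[-1], self._ws[-1]
            w = ws + self._rng.gauss(0.0, math.sqrt(t - s))
        else:
            # Between stored times s < t < u: sample from the Brownian bridge.
            s, u = self._ts[i - 1], self._ts[i]
            ws, wu = self._ws[i - 1], self._ws[i]
            mean = ws + (t - s) / (u - s) * (wu - ws)
            var = (t - s) * (u - t) / (u - s)
            w = self._rng.gauss(mean, math.sqrt(var))
        # O(n) for a Python list; a balanced tree (e.g. std::map) makes this O(log n).
        self._ts.insert(i, t)
        self._ws.insert(i, w)
        return w


bm = SimpleBrownianPath(seed=0)
w1 = bm(1.0)      # extends the path to t = 1
w_half = bm(0.5)  # bridge between t = 0 and t = 1
assert bm(1.0) == w1 and bm(0.5) == w_half  # repeated queries are consistent
```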

The Brownian tree algorithm is implemented in a separate class, and the C++-based implementation can be imported with `from torchsde.brownian_lib import BrownianTree`. Internally, it caches queries at an intermediate level of the tree, so search doesn't always start from the root. This speeds things up quite a bit and gives users fine-grained control over trading memory for speed.
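The tree-with-cache idea can be sketched as follows. This is a simplified, hypothetical illustration, not torchsde's `BrownianTree`: $W$ on $[t_0, t_1]$ is defined by recursive bisection, each node draws its midpoint from a Brownian bridge using a seed derived deterministically from its parent, and a query descends until an endpoint lies within `tol` of $t$, returning that endpoint's value (an approximation at resolution `tol`). All nodes at depth `cache_depth` are precomputed, so a query starts from the cached node containing $t$ rather than from the root; raising `cache_depth` spends memory to save descent steps.

```python
import math
import random


class SimpleBrownianTree:
    """Hypothetical scalar Brownian tree on [t0, t1] (illustration only)."""

    def __init__(self, t0=0.0, t1=1.0, seed=0, tol=1e-6, cache_depth=4):
        self.t0, self.t1, self.tol = t0, t1, tol
        rng = random.Random(seed)
        w1 = rng.gauss(0.0, math.sqrt(t1 - t0))
        # A node is (s, u, W(s), W(u), seed for sampling its midpoint).
        root = (t0, t1, 0.0, w1, rng.getrandbits(64))
        # Memory/speed knob: keep all 2**cache_depth nodes of one level.
        level = [root]
        for _ in range(cache_depth):
            level = [child for node in level for child in self._split(node)]
        self._cache = level

    @staticmethod
    def _split(node):
        """Sample the midpoint via a Brownian bridge; children get derived seeds."""
        s, u, ws, wu, seed = node
        rng = random.Random(seed)
        m = 0.5 * (s + u)
        wm = rng.gauss(0.5 * (ws + wu), math.sqrt((u - s) / 4.0))
        left = (s, m, ws, wm, rng.getrandbits(64))
        right = (m, u, wm, wu, rng.getrandbits(64))
        return left, right

    def __call__(self, t):
        if not self.t0 <= t <= self.t1:
            raise ValueError("query outside [t0, t1]")
        # Jump straight to the cached node containing t instead of the root.
        n = len(self._cache)
        idx = min(int((t - self.t0) / (self.t1 - self.t0) * n), n - 1)
        node = self._cache[idx]
        while True:
            s, u, ws, wu, _ = node
            if abs(t - s) <= self.tol:
                return ws
            if abs(t - u) <= self.tol:
                return wu
            left, right = self._split(node)
            node = left if t < left[1] else right


bt = SimpleBrownianTree(seed=1, cache_depth=3)
assert bt(0.3) == bt(0.3)  # reproducible without storing the path
```

Because every value is regenerated from seeds, only the cached level has to be kept in memory.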