# ChrisRackauckas/diffeq_vs_torchsde.md

Last active September 28, 2022 20:49
torchsde vs DifferentialEquations.jl / DiffEqFlux.jl (Julia) benchmarks

# torchsde vs DifferentialEquations.jl / DiffEqFlux.jl (Julia)

This example is a 4-dimensional geometric Brownian motion. The code for the torchsde version is pulled directly from the torchsde README so that it is a fair comparison against the author's own code. The only change to that example is the addition of a dt choice so that the simulation method and time step match between the two programs.
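
For reference, the model can be written out explicitly. This assumes the standard Itô form of geometric Brownian motion, with drift $\mu$ and volatility $\sigma$ applied to each of the four components and an independent Brownian motion per component (diagonal noise). The notation is the textbook one rather than something stated in the benchmark itself. The system and its closed-form solution are:

$$
dX_t^{(i)} = \mu X_t^{(i)}\,dt + \sigma X_t^{(i)}\,dW_t^{(i)}, \qquad i = 1,\dots,4,
$$

$$
X_t^{(i)} = X_0^{(i)} \exp\!\left(\left(\mu - \tfrac{1}{2}\sigma^2\right)t + \sigma W_t^{(i)}\right).
$$

The drift term corresponds to `f` and the diffusion term to `g` in both listings below.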

The SDE is solved 100 times. The summary of the results is as follows:

• torchsde: 1.87 seconds
• DifferentialEquations.jl: 0.00115 seconds

This demonstrates a roughly 1,600x performance difference in favor of Julia on the Python library's own README example. Further testing against torchsde could not be completed because of these performance issues.
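
The ratio can be checked directly from the two reported timings. The short sketch below only redoes that arithmetic; the variable names are illustrative and nothing here is measured.

```python
# Timings reported in the summary above: seconds for 100 solves.
torchsde_seconds = 1.87
diffeq_seconds = 0.00115

# Ratio of the two totals, and the implied cost of a single solve.
speedup = torchsde_seconds / diffeq_seconds
per_solve_torchsde_ms = torchsde_seconds / 100 * 1e3
per_solve_diffeq_us = diffeq_seconds / 100 * 1e6

print(f"speedup: {speedup:.0f}x")
print(f"torchsde: {per_solve_torchsde_ms:.1f} ms per solve")
print(f"DifferentialEquations.jl: {per_solve_diffeq_us:.1f} us per solve")
```

In other words, a single solve costs on the order of milliseconds in one case and microseconds in the other.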

We note that the performance difference in the context of neural SDEs is likely smaller because more of the time is spent in matrix multiplication kernels. However, given that full SDE training examples like the one demonstrated here generally take about a minute, we still expect a major performance difference, but we currently do not have the compute time to run a full demonstration.

DifferentialEquations.jl (Julia):

    using StochasticDiffEq, StaticArrays

    const μ = 0.5ones(4)
    const σ = 0.1ones(4)
    f(du,u,p,t) = du .= μ .* u
    g(du,u,p,t) = du .= σ .* u
    u0 = 0.1ones(4)
    tspan = (0.0,1.0)
    saveat = range(0.0,1.0,length=20)
    prob = SDEProblem(f,g,u0,tspan)

    @time for i in 1:100
      sol = solve(prob,SRIW1(),adaptive=false,dt=saveat[2])
    end

    #0.001154 seconds (16.50 k allocations: 1.109 MiB)
    #0.001162 seconds (16.50 k allocations: 1.109 MiB)
    #0.001148 seconds (16.50 k allocations: 1.109 MiB)

torchsde (Python):

    import timeit  # needed for the timing call at the bottom

    import torch
    from torchsde import sdeint

    class SDE(torch.nn.Module):

        def __init__(self, mu, sigma):
            super().__init__()
            self.noise_type = "diagonal"
            self.sde_type = "ito"

            self.mu = mu
            self.sigma = sigma

        def f(self, t, y):
            return self.mu * y

        def g(self, t, y):
            return self.sigma * y

    batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
    geometric_bm = SDE(mu=0.5, sigma=1)
    y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
    ts = torch.linspace(0, 1, 20)
    ys = sdeint(geometric_bm, y0, ts)

    def time_func():
        ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})

    timeit.Timer(time_func).timeit(number=100)

    # 1.8681289999999535 seconds
    # 1.8695188000001508 seconds
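
To make the fixed-step setup concrete without either library, here is a minimal standard-library sketch of the same problem. It uses plain Euler-Maruyama, which is a deliberately simpler scheme than the SRIW1 method used in the Julia listing and is not what torchsde runs either, so its timings are not comparable to the numbers above. The function name `euler_maruyama_gbm` and the constants are illustrative assumptions; the parameter values are taken from the Julia listing.

```python
import math
import random
import timeit

MU, SIGMA = 0.5, 0.1   # drift and volatility, as in the Julia listing
N_STATES = 4           # four independent components
N_STEPS = 19           # 20 save points on [0, 1] means 19 steps
DT = 1.0 / N_STEPS     # same step as saveat[2] / ts[1] in the listings


def euler_maruyama_gbm(rng):
    """Simulate one path of dX = MU*X dt + SIGMA*X dW with fixed step DT."""
    x = [0.1] * N_STATES
    path = [list(x)]
    sqrt_dt = math.sqrt(DT)
    for _ in range(N_STEPS):
        # Each component gets its own Gaussian increment (diagonal noise).
        x = [xi + MU * xi * DT + SIGMA * xi * sqrt_dt * rng.gauss(0.0, 1.0)
             for xi in x]
        path.append(list(x))
    return path


rng = random.Random(0)
path = euler_maruyama_gbm(rng)

# Time 100 solves, mirroring the structure of the benchmarks above.
elapsed = timeit.timeit(lambda: euler_maruyama_gbm(rng), number=100)
print(f"100 solves: {elapsed:.4f} s, final state: {path[-1]}")
```

The sketch only shows the shape of the computation being timed: 19 fixed steps over four states, repeated 100 times.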

### lxuechen commented Aug 4, 2020

Thanks for the update!

> I would be cautious with saying only your applications are real: there are probably orders of magnitude more people doing mathematical finance, model-informed drug development, and systems biology with SDEs than training neural SDEs for image processing (at least right now), and those disciplines naturally arrive at more heterogeneous models that cannot always be expressed in matmuls.

I think there might have been a bit of miscommunication here possibly due to my specific wording. I agree that there definitely are important and high-impact real problems which might not require too many large matmuls, and never have I questioned this. The point I wanted to make was that the problems we intend to concentrate on at the moment are those belonging to the subset of real problems which are also "matmul heavy".

> I don't think it's an issue to point out that this can be an issue for someone just trying to pick up torchsde as a general-purpose SDE solver, even though it is outside of the realm you generally focus on. I did this little bit of benchmarking because some particle physicists were curious how it would benchmark on 3 complex SDE systems (so 6 real SDEs), meaning that this benchmark is quite close to their use case and is a notable result.

I agree. At this moment, torchsde isn't a fully polished general-purpose SDE solver library yet, and there are many aspects of potential improvement. That being said, I did try to point this out with the statement near the bottom of README.md, stating that this is still a research project. But in any case, thanks again for this benchmark!

### lxuechen commented Aug 4, 2020

Speaking of jitting, here's an example of how one would jit the SDE:

    import torch
    from torchsde import sdeint
    import timeit

    class SDE(torch.nn.Module):

        def __init__(self, mu, sigma):
            super().__init__()
            self.noise_type = "diagonal"
            self.sde_type = "ito"

            self.mu = mu
            self.sigma = sigma

        @torch.jit.export
        def f(self, t, y):
            return self.mu * y

        @torch.jit.export
        def g(self, t, y):
            return self.sigma * y

    batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
    geometric_bm = SDE(mu=0.5, sigma=1)

    # Works for torch==1.6.0.
    geometric_bm = torch.jit.script(geometric_bm)

    y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
    ts = torch.linspace(0, 1, 20)

    def time_func():
        ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})

    print(timeit.Timer(time_func).timeit(number=100))

For now, scripting actually makes it slightly slower. In the ideal scenario, though, when there's some control flow or indexing in f and g, there might be an improvement.

### ChrisRackauckas commented Aug 4, 2020

Thanks for the confirmation of the JIT script. Yeah, I was trying it, but I think torchjit really works well in a different regime. Good to know I had that right.

BTW, what exactly is the trapezoidal approximation there? Is that for filling in gaps of the Brownian tree?

The implementation of BrownianPath is actually not based on the Brownian tree algorithm we proposed. Rather, it simply stores all the queries for speed. The C++ implementation is based on std::map and has O(log n) insertion and Brownian bridge query. I had some plans for using a splay tree for this, since queries by the solver are usually quite close to each other, but that's not near the top of my priority list for now.

The Brownian tree algorithm is implemented in a separate class, and the C++-based implementation can be imported as from torchsde.brownian_lib import BrownianTree. Internally, it caches queries at an intermediate level of the tree, so search doesn't always start from the root. This speeds things up by quite a bit and gives users the control to trade memory for speed in a fine-grained manner.
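
The "store every query, answer new ones by Brownian bridge" idea described for BrownianPath can be sketched in a few lines of standard-library Python. This is only an illustration of the technique under stated assumptions, not torchsde's code: the class name `CachedBrownianPath` is hypothetical, the path is scalar, and sorted Python lists with `bisect` stand in for std::map, so the search is O(log n) but the list insertion is O(n).

```python
import bisect
import math
import random


class CachedBrownianPath:
    """Hypothetical scalar Brownian path that remembers every queried value."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self._times = [0.0]    # sorted query times
        self._values = [0.0]   # W(t) at those times, with W(0) = 0

    def __call__(self, t):
        if t < 0.0:
            raise ValueError("t must be non-negative")
        i = bisect.bisect_left(self._times, t)   # O(log n) search
        if i < len(self._times) and self._times[i] == t:
            return self._values[i]               # already queried: reuse it
        t0, w0 = self._times[i - 1], self._values[i - 1]
        if i == len(self._times):
            # Past the last known time: add an independent Gaussian increment.
            w = w0 + math.sqrt(t - t0) * self._rng.gauss(0.0, 1.0)
        else:
            # Between two known times: sample from the Brownian bridge.
            t1, w1 = self._times[i], self._values[i]
            mean = w0 + (t - t0) / (t1 - t0) * (w1 - w0)
            var = (t1 - t) * (t - t0) / (t1 - t0)
            w = mean + math.sqrt(var) * self._rng.gauss(0.0, 1.0)
        self._times.insert(i, t)     # O(n) for a list; a tree map avoids this
        self._values.insert(i, w)
        return w


bm = CachedBrownianPath(seed=42)
w_end = bm(1.0)   # extends the path to t = 1
w_mid = bm(0.5)   # bridges between t = 0 and t = 1
```

Repeating a query such as `bm(0.5)` returns the stored value, which is what keeps the sampled path consistent when a solver revisits the same time.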