@ChrisRackauckas
Last active September 28, 2022 20:49
torchsde vs DifferentialEquations.jl / DiffEqFlux.jl (Julia) benchmarks

torchsde vs DifferentialEquations.jl / DiffEqFlux.jl (Julia)

This example is a 4-dimensional geometric Brownian motion. The code for the torchsde version is pulled directly from the torchsde README so that it would be a fair comparison against the author's own code. The only change to that example is the addition of a dt choice so that the simulation method and time step match between the two programs.
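For readers less familiar with the test problem, here is a minimal pure-Python sketch of the same 4-dimensional diagonal geometric Brownian motion, $dX_i = \mu_i X_i\,dt + \sigma_i X_i\,dW_i$, using the μ, σ and initial values from the Julia script below. It is an illustration only, under our own assumptions: it uses plain Euler-Maruyama (not SRIW1 or torchsde's default method) with the same fixed step dt = 1/19, and compares the result against the closed-form Itô solution $X_t = X_0 \exp((\mu - \sigma^2/2)t + \sigma W_t)$ driven by the same Brownian increments. The function names are hypothetical.

```python
import math
import random


def gbm_euler_maruyama(mu, sigma, x0, t1, n_steps, rng):
    """Fixed-step Euler-Maruyama for diagonal GBM dX_i = mu_i X_i dt + sigma_i X_i dW_i.

    Returns the final state and the accumulated Brownian motion W(t1) per component.
    """
    dt = t1 / n_steps
    x = list(x0)
    w = [0.0] * len(x0)
    for _ in range(n_steps):
        for i in range(len(x)):
            dw = rng.gauss(0.0, math.sqrt(dt))  # independent noise per component (diagonal)
            x[i] += mu[i] * x[i] * dt + sigma[i] * x[i] * dw
            w[i] += dw
    return x, w


def gbm_exact(mu, sigma, x0, t, w):
    """Closed-form Ito solution X_t = X_0 exp((mu - sigma^2 / 2) t + sigma W_t)."""
    return [xi * math.exp((m - 0.5 * s * s) * t + s * wi)
            for xi, m, s, wi in zip(x0, mu, sigma, w)]


rng = random.Random(0)
mu, sigma, x0 = [0.5] * 4, [0.1] * 4, [0.1] * 4         # values from the Julia script
x_em, w = gbm_euler_maruyama(mu, sigma, x0, 1.0, 19, rng)  # dt = 1/19, i.e. saveat[2]
x_ex = gbm_exact(mu, sigma, x0, 1.0, w)                  # same Brownian path
for a, b in zip(x_em, x_ex):
    print(f"Euler-Maruyama {a:.5f}   exact {b:.5f}")
```

With σ = 0.1 the two columns agree closely; the higher-order methods in the benchmark exist to shrink this discretization error further at the same dt.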

The SDE is solved 100 times. The summary of the results is as follows:

  • torchsde: 1.87 seconds
  • DifferentialEquations.jl: 0.00115 seconds

This demonstrates a 1,600x performance difference in favor of Julia on the Python library's README example. We were not able to complete further testing against torchsde because of these performance issues.
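The headline ratio follows directly from the two totals above. A quick check, which also converts them to per-solve times (each total covers 100 solves):

```python
torchsde_total_s = 1.87    # seconds for 100 solves, reported above
julia_total_s = 0.00115    # seconds for 100 solves, reported above
n_solves = 100

speedup = torchsde_total_s / julia_total_s
torchsde_per_solve_ms = 1e3 * torchsde_total_s / n_solves
julia_per_solve_us = 1e6 * julia_total_s / n_solves

print(f"speedup ~ {speedup:.0f}x")
print(f"torchsde: {torchsde_per_solve_ms:.1f} ms per solve")
print(f"Julia:    {julia_per_solve_us:.1f} us per solve")
```

This gives roughly 1,626x, or about 18.7 ms per torchsde solve versus about 11.5 µs per Julia solve.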

Note about regularity

We note that the performance difference in the context of neural SDEs is likely smaller, since more of the time is spent in matrix multiplication kernels. However, given that full SDE training examples like the ones demonstrated here generally take about a minute, we still expect a major performance difference, but we currently do not have the compute time to run a full demonstration.
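One way to make the "likely smaller" argument quantitative is Amdahl's law: if a fraction p of torchsde's runtime is spent in matrix multiplication kernels that run equally fast in both languages, and only the remaining overhead (solver loop, dispatch) is sped up by a factor s, the end-to-end speedup is 1 / ((1 - p)/s + p). The sketch below is our own illustration; the fractions p are hypothetical inputs rather than measurements, and reusing the ~1,600x figure as the overhead speedup is an assumption.

```python
def overall_speedup(kernel_fraction, overhead_speedup, kernel_speedup=1.0):
    """Amdahl's-law estimate of end-to-end speedup.

    kernel_fraction:  share of the slower program's runtime spent in shared kernels (e.g. matmuls)
    overhead_speedup: speedup on everything else (solver loop, dispatch, allocation)
    kernel_speedup:   speedup on the kernel part; 1.0 means both call equally fast kernels
    """
    if not 0.0 <= kernel_fraction <= 1.0:
        raise ValueError("kernel_fraction must be in [0, 1]")
    rest = 1.0 - kernel_fraction
    return 1.0 / (rest / overhead_speedup + kernel_fraction / kernel_speedup)


# Hypothetical kernel fractions, with the ~1,600x overhead ratio assumed from above
for p in (0.0, 0.5, 0.9, 0.99):
    print(f"p = {p:.2f}: ~{overall_speedup(p, 1600.0):.2f}x")
```

The gap shrinks as p grows, which is the mechanism the note describes; how large p actually is for a given neural SDE would have to be measured.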

```julia
using StochasticDiffEq, StaticArrays
# 4-dimensional geometric Brownian motion with diagonal noise
const μ = 0.5ones(4)
const σ = 0.1ones(4)
f(du,u,p,t) = du .= μ .* u   # drift (in-place)
g(du,u,p,t) = du .= σ .* u   # diagonal diffusion (in-place)
u0 = 0.1ones(4)
tspan = (0.0,1.0)
saveat = range(0.0,1.0,length=20)
prob = SDEProblem(f,g,u0,tspan)
# 100 fixed-step SRIW1 solves; dt = saveat[2] = 1/19 matches ts[1] in the torchsde script
@time for i in 1:100
    sol = solve(prob,SRIW1(),adaptive=false,dt=saveat[2])
end
#0.001154 seconds (16.50 k allocations: 1.109 MiB)
#0.001162 seconds (16.50 k allocations: 1.109 MiB)
#0.001148 seconds (16.50 k allocations: 1.109 MiB)
```
```python
import timeit

import torch
from torchsde import sdeint


class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"
        self.mu = mu
        self.sigma = sigma

    # Drift
    def f(self, t, y):
        return self.mu * y

    # Diffusion (diagonal noise)
    def g(self, t, y):
        return self.sigma * y


batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)
y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)
ys = sdeint(geometric_bm, y0, ts)  # Example call from the torchsde README.


def time_func():
    # Fixed step dt = ts[1] = 1/19, matching the Julia script
    ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})


print(timeit.Timer(time_func).timeit(number=100))
# 1.8681289999999535 seconds
# 1.8695188000001508 seconds
```
@lxuechen

lxuechen commented Aug 4, 2020

Thanks for the update!

> I would be cautious with saying only your applications are real: there are probably orders of magnitude more people doing mathematical finance, model-informed drug development, and systems biology with SDEs than training neural SDEs for image processing (at least right now), and those disciplines naturally arrive at more heterogeneous models that cannot always be expressed in matmuls.

I think there might have been a bit of miscommunication here, possibly due to my wording. I agree that there definitely are important and high-impact real problems which might not require many large matmuls, and I have never questioned this. The point I wanted to make was that the problems we intend to concentrate on at the moment are those belonging to the subset of real problems that are also "matmul heavy".

> I don't think it's an issue to point out that this can be an issue for someone just trying to pick up torchsde as a general-purpose SDE solver, even though it is outside of the realm you generally focus on. I did this little bit of benchmarking because some particle physicists were curious how it would benchmark on 3 complex SDE systems (so 6 real SDEs), meaning that this benchmark is quite close to their use case and is a notable result.

I agree. At this moment, torchsde isn't a fully polished general-purpose SDE solver library yet, and there are many aspects that could be improved. That being said, I did try to point this out with the statement near the bottom of README.md, which says that this is still a research project. But in any case, thanks again for this benchmark!

@lxuechen

lxuechen commented Aug 4, 2020

Speaking of jitting, here's an example of how one would jit the SDE:

```python
import torch
from torchsde import sdeint
import timeit


class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"

        self.mu = mu
        self.sigma = sigma

    @torch.jit.export
    def f(self, t, y):
        return self.mu * y

    @torch.jit.export
    def g(self, t, y):
        return self.sigma * y


batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)

# Works for torch==1.6.0.
geometric_bm = torch.jit.script(geometric_bm)

y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)


def time_func():
    ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': False})


print(timeit.Timer(time_func).timeit(number=100))
```

For now, scripting actually makes it slightly slower, though in the ideal scenario where f and g contain some control flow or indexing, there might be an improvement.

@ChrisRackauckas (author)

Thanks for confirming the JIT scripting. Yeah, I was trying it, but I think the torch JIT really works well in a different regime. Good to know I had that right.

BTW, what exactly is the trapezoidal approximation there? Is that for filling in gaps of the Brownian tree?

@lxuechen

lxuechen commented Aug 6, 2020

The trapezoidal approximation estimates $\int_{t_k}^{t_{k+1}} W_s \, ds$ based solely on Brownian motion increments. It will soon be replaced with something faster that tracks two sequences of random variables, like the RSwM algorithms you have.
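For reference, here is a short derivation (ours, not from the thread) of what the trapezoidal rule leaves out. With step size $h = t_{k+1} - t_k$, increment $\Delta W_k = W_{t_{k+1}} - W_{t_k}$ and $J_k = \int_{t_k}^{t_{k+1}} (W_s - W_{t_k})\,ds$, the pair $(\Delta W_k, J_k)$ is jointly Gaussian, so conditioning on the increment gives:

```latex
\begin{aligned}
\operatorname{Var}(\Delta W_k) &= h, \qquad
\operatorname{Var}(J_k) = \frac{h^3}{3}, \qquad
\operatorname{Cov}(J_k, \Delta W_k) = \frac{h^2}{2}, \\
J_k \mid \Delta W_k &\sim \mathcal{N}\!\left(\frac{h}{2}\,\Delta W_k,\;
  \frac{h^3}{3} - \frac{(h^2/2)^2}{h}\right)
  = \mathcal{N}\!\left(\frac{h}{2}\,\Delta W_k,\; \frac{h^3}{12}\right), \\
\int_{t_k}^{t_{k+1}} W_s\,ds &= h\,W_{t_k} + J_k
  \approx \frac{h}{2}\left(W_{t_k} + W_{t_{k+1}}\right) \quad \text{(trapezoidal rule)}.
\end{aligned}
```

Since $h\,W_{t_k} + \tfrac{h}{2}\Delta W_k = \tfrac{h}{2}(W_{t_k} + W_{t_{k+1}})$, the trapezoidal value is exactly the conditional mean; it drops an independent $\mathcal{N}(0, h^3/12)$ term.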

For fixed-step solvers, we don't necessarily need to use an approximation, since the integrated term can be sampled exactly. For adaptive solvers, since there might be backtracking, the solution will be wrong if we don't store these integrated terms and do conditioning. Tracking the extra sequence required some extra work, so for the initial release I used the approximation. The option is off by default to make sure adaptive solvers' solutions are correct. Undoubtedly, using the approximation could result in a drop in order if the inner loop is not accurate enough. That said, this issue will be resolved when I merge in the faster data structures that record the integrated terms with proper conditioning.
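A minimal sketch of exact fixed-step sampling, assuming the standard Gaussian conditional $J \mid \Delta W \sim \mathcal{N}(h\,\Delta W/2,\ h^3/12)$ for $J = \int_{t_k}^{t_{k+1}} (W_s - W_{t_k})\,ds$. This is our own standard-library illustration, not torchsde's code, and the function names are hypothetical. It draws the increment and the integral together, then checks empirically that the trapezoidal value misses exactly an independent part with variance $h^3/12$.

```python
import math
import random


def sample_step(h, rng):
    """Exactly sample (dW, J) for one step of size h.

    dW = W(t+h) - W(t) and J = integral over [t, t+h] of (W(s) - W(t)) ds,
    using J | dW ~ N(h * dW / 2, h**3 / 12).
    """
    dw = rng.gauss(0.0, math.sqrt(h))
    j = 0.5 * h * dw + rng.gauss(0.0, math.sqrt(h ** 3 / 12.0))
    return dw, j


def trapezoidal_j(h, dw):
    """Trapezoidal approximation of J from the increment alone (the conditional mean)."""
    return 0.5 * h * dw


def sample_variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)


rng = random.Random(0)
h, n = 0.1, 20000
pairs = [sample_step(h, rng) for _ in range(n)]
missed = [j - trapezoidal_j(h, dw) for dw, j in pairs]
print(sample_variance([j for _, j in pairs]), h ** 3 / 3)  # total variance of J
print(sample_variance(missed), h ** 3 / 12)                # part the trapezoid drops
```

With a fixed step each step is sampled once and never revisited, so these draws need no storage; an adaptive solver that backtracks would have to keep them and condition on them.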

The implementation of BrownianPath is actually not based on the Brownian tree algorithm we proposed. Rather, it simply stores all the queries for speed. The C++ implementation is based on std::map and has O(log n) insertion and Brownian bridge queries. I had some plans for using a splay tree for this, since queries by the solver are usually quite close to each other, but that's not near the top of my priority list for now.
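A rough Python sketch of the store-every-query idea, assuming W(t0) = 0 and only queries at or after t0. A time past the last stored point gets an independent Gaussian increment, and a time between two stored points is filled in with a Brownian bridge, $W(t) \mid W(s), W(u) \sim \mathcal{N}\big(W(s) + \tfrac{t-s}{u-s}(W(u)-W(s)),\ \tfrac{(t-s)(u-t)}{u-s}\big)$. The class name is hypothetical and this is not torchsde's BrownianPath. A sorted Python list with `bisect` gives O(log n) search but O(n) insertion; a balanced tree such as `std::map` is what makes insertion O(log n) as well.

```python
import bisect
import math
import random


class BrownianPathSketch:
    """Hypothetical sketch: store every queried (t, W(t)) pair in sorted order."""

    def __init__(self, t0=0.0, w0=0.0, seed=0):
        self._ts = [t0]  # sorted query times
        self._ws = [w0]  # W at those times
        self._rng = random.Random(seed)

    def __call__(self, t):
        i = bisect.bisect_left(self._ts, t)
        if i < len(self._ts) and self._ts[i] == t:
            return self._ws[i]  # already queried: replay the stored value
        if i == 0:
            raise ValueError("queries before t0 are not supported in this sketch")
        s, ws = self._ts[i - 1], self._ws[i - 1]
        if i == len(self._ts):
            # Past the last stored time: independent Gaussian increment
            w = ws + self._rng.gauss(0.0, math.sqrt(t - s))
        else:
            # Between two stored times: Brownian bridge
            u, wu = self._ts[i], self._ws[i]
            mean = ws + (t - s) / (u - s) * (wu - ws)
            std = math.sqrt((t - s) * (u - t) / (u - s))
            w = self._rng.gauss(mean, std)
        self._ts.insert(i, t)  # O(n) for a Python list
        self._ws.insert(i, w)
        return w


bm = BrownianPathSketch(seed=0)
print(bm(1.0), bm(0.25), bm(1.0))  # the repeated query returns the stored value
```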

The Brownian tree algorithm is implemented in a separate class, and the C++-based implementation can be imported with `from torchsde.brownian_lib import BrownianTree`. Internally, it caches queries at an intermediate level of the tree, so search doesn't always start from the root. This speeds things up quite a bit and gives users fine-grained control over trading memory for speed.
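The Brownian tree idea can be sketched in the same style. Instead of storing queries, each dyadic interval gets a seed; its midpoint value is regenerated from that seed with a Brownian bridge, and two child seeds are derived for the halves, so any query is recomputed deterministically by descending the tree. Caching the top levels (here a hypothetical `cache_depth` parameter) means those nodes are looked up instead of resampled, which is the memory-for-speed trade-off described above. This is our simplified illustration, assuming W(t0) = 0 and a given W(t1); it is not torchsde's `BrownianTree` implementation or API.

```python
import math
import random


class BrownianTreeSketch:
    """Hypothetical seeded Brownian tree on [t0, t1] with W(t0) = 0 and W(t1) = w1."""

    def __init__(self, t0, t1, w1, seed=0, tol=1e-6, cache_depth=4):
        self.t0, self.t1, self.w1 = t0, t1, w1
        self.seed, self.tol, self.cache_depth = seed, tol, cache_depth
        self._cache = {}  # (depth, index) -> (midpoint, W(midpoint), left seed, right seed)

    @staticmethod
    def _bridge(s, t, u, ws, wu, rng):
        # Sample W(t) given W(s) = ws and W(u) = wu, for s < t < u.
        mean = ws + (t - s) / (u - s) * (wu - ws)
        std = math.sqrt((t - s) * (u - t) / (u - s))
        return rng.gauss(mean, std)

    def _split(self, s, u, ws, wu, seed):
        # Everything about this node is a deterministic function of its seed.
        rng = random.Random(seed)
        m = 0.5 * (s + u)
        wm = self._bridge(s, m, u, ws, wu, rng)
        return m, wm, rng.getrandbits(64), rng.getrandbits(64)

    def __call__(self, t):
        if not self.t0 <= t <= self.t1:
            raise ValueError("t must lie in [t0, t1]")
        s, u, ws, wu, seed = self.t0, self.t1, 0.0, self.w1, self.seed
        depth, index = 0, 0
        while u - s > self.tol:
            key = (depth, index)
            node = self._cache.get(key)
            if node is None:
                node = self._split(s, u, ws, wu, seed)
                if depth < self.cache_depth:
                    self._cache[key] = node  # keep only the top levels in memory
            m, wm, left_seed, right_seed = node
            if t <= m:
                u, wu, seed, index = m, wm, left_seed, 2 * index
            else:
                s, ws, seed, index = m, wm, right_seed, 2 * index + 1
            depth += 1
        if t == s:
            return ws
        if t == u:
            return wu
        # Inside a leaf narrower than tol: one last bridge draw from the leaf's seed.
        return self._bridge(s, t, u, ws, wu, random.Random(seed))


tree = BrownianTreeSketch(0.0, 1.0, w1=0.3, seed=42)
print(tree(0.37), tree(0.37))  # identical, although nothing per-query was stored
```

Because node values depend only on seeds, caching changes the running time but not the returned values.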
