Skip to content

Instantly share code, notes, and snippets.

@ForceBru
Last active Aug 10, 2022
Embed
What would you like to do?
Benchmark of Julia autodiff

Benchmark of several Julia autodiff packages (including Symbolics.jl)

The goal is to differentiate a log-likelihood function — the workhorse of probability theory, mathematical statistics, and machine learning.

Here, it's the log-likelihood of a Gaussian mixture model:

# Density of the normal distribution N(mean, var) evaluated at `x`.
# Note that the third argument is the variance, not the standard deviation.
function normal_pdf(x::Real, mean::Real, var::Real)
    deviation = x - mean
    return exp(-deviation^2 / (2var)) / sqrt(2π * var)
end

# Log-likelihood of `data` under a Gaussian mixture.
# `params` is laid out as [weights; means; stds], each of length K = length(params) ÷ 3.
function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights = @view params[1:K]
    means = @view params[K+1:2K]
    stds = @view params[2K+1:end]

    # Mixture density at a single observation: weighted sum of the component densities
    density(x) = sum(w * normal_pdf(x, mu, sd^2) for (w, mu, sd) in zip(weights, means, stds))

    return sum(log(density(x)) for x in data)
end

Package versions

Julia 1.8.0-rc3

  [6e4b80f9] BenchmarkTools v1.3.1
  [f6369f11] ForwardDiff v0.10.32
  [37e2e3b7] ReverseDiff v1.14.1
  [0c5d862f] Symbolics v4.10.3
  [e88e6eb3] Zygote v0.6.43

Benchmark results

Setup:

  • 500 data samples
  • 4 mixture components: 4 weights, 4 means, 4 standard deviations => 12 parameters total

Mean time ± standard deviation (10000 samples with 1 evaluation each), fastest first:

  1. JAX (Python): 61.308 μs ± 5.368 μs
  2. Zygote.jl: 271.229 μs ± 714.442 μs
  3. ForwardDiff.jl: 515.443 μs ± 665.059 μs
  4. Symbolics.jl (precomputed gradient, mutating function): 809.821 μs ± 273.313 μs
  5. ReverseDiff.jl: 1.633 ms ± 74.832 μs
import Pkg; Pkg.status()
@info "Loading packages..."
import Random, DelimitedFiles
import ForwardDiff, ReverseDiff, Zygote, Symbolics
using BenchmarkTools
const AV = AbstractVector{T} where T
# ========== Objective function ==========
"""
    normal_pdf(x::Real, mean::Real, var::Real)

Return the density of the normal distribution with mean `mean` and variance `var`
(not the standard deviation) evaluated at `x`.
"""
function normal_pdf(x::Real, mean::Real, var::Real)
    deviation = x - mean
    return exp(-deviation^2 / (2var)) / sqrt(2π * var)
end
"""
    mixture_pdf(x, weights, means, vars)

Return the density at `x` of a Gaussian mixture: the sum over components of
`weights[k] * normal_pdf(x, means[k], vars[k])`. The three vectors are iterated
in lockstep with `zip`.
"""
function mixture_pdf(x::Real, weights::AV{<:Real}, means::AV{<:Real}, vars::AV{<:Real})
    components = zip(weights, means, vars)
    return sum(w * normal_pdf(x, mu, v) for (w, mu, v) in components)
end
# Untyped fallback method of `normal_pdf`: same formula as the `Real` method,
# selected by dispatch when some argument is not a `Real`.
function normal_pdf(x, mean, var)
    kernel = exp(-(x - mean)^2 / (2var))
    normaliser = sqrt(2π * var)
    return kernel / normaliser
end
"""
    mixture_loglikelihood(params, data)

Return the log-likelihood of `data` under a Gaussian mixture model.

`params` is laid out as `[weights; means; stds]`, each part of length
`K = length(params) ÷ 3`. This is the broadcast (matrix) formulation of
"for every sample, sum the weighted component densities, take the log, and add up".
"""
function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights = @view params[1:K]
    means = @view params[K+1:2K]
    stds = @view params[2K+1:end]

    # Entry (i, k) is the density of component k at data[i]; size (N, K)
    densities = normal_pdf.(data, means', stds' .^ 2)
    # Weighted sum over the component axis: one mixture density per sample
    per_sample = sum(densities .* weights', dims=2)
    return sum(per_sample .|> log)
end
"""
    generate_gradient(out_fname, K)

Symbolically differentiate the log of a `K`-component mixture density at a single
point `x` with respect to `[ws; mus; stds]`, build a function for that gradient
with `Symbolics.build_function`, and write the source of the second (in-place)
function it returns to `out_fname`.

Returns whatever `write` returns for the generated source string.
"""
function generate_gradient(out_fname::AbstractString, K::Integer)
    @assert K > 0
    Symbolics.@variables x ws[1:K] mus[1:K] stds[1:K]
    args = [x, ws, mus, stds]

    # The model is parameterised by standard deviations; mixture_pdf expects variances
    variances = collect(stds) .^ 2
    log_density = log(mixture_pdf(x, ws, mus, variances))
    grad_expr = Symbolics.gradient(log_density, [ws; mus; stds])

    # Only the second of the two returned functions is written out
    _, inplace_fn = Symbolics.build_function(grad_expr, args...)
    return write(out_fname, string(inplace_fn))
end
# ========== Gradient with Symbolics.jl ==========
@info "Generating gradient functions..."
# GRAD_FNS[K] is meant to hold the gradient function for a K-component mixture.
# Index 1 is a `nothing` placeholder because generation below starts at K = 2.
GRAD_FNS = Union{Nothing, Function}[nothing]
for K in 2:5
fname = "grad_$K.jl"
# Writes the generated source to `fname`; @show prints generate_gradient's return value
@show generate_gradient(fname, K)
# `include` evaluates the generated file; its result is stored at index K
push!(GRAD_FNS, include(fname))
end
"""
    my_gradient!(out, tmp, xs, params)

Overwrite `out` with the sum over all samples `x` in `xs` of the per-sample gradient
computed by the generated function `GRAD_FNS[K]`, where `K = length(params) ÷ 3`.

# Arguments
- `out`: output buffer; zeroed first, then accumulated into.
- `tmp`: scratch buffer that the generated function writes each per-sample gradient to.
- `xs`: data samples.
- `params`: parameters laid out as `[weights; means; stds]`, each of length `K`.

Returns `nothing`.
"""
function my_gradient!(out::AV{<:Real}, tmp::AV{<:Real}, xs::AV{<:Real}, params::AV{<:Real})
    K = length(params) ÷ 3
    grad! = GRAD_FNS[K]
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
    # `GRAD_FNS` is a non-const global with an abstract element type, so the concrete
    # type of `grad!` is unknown here. Passing it through a function barrier means the
    # dynamic dispatch happens once per call instead of once per sample.
    return _accumulate_gradient!(grad!, out, tmp, xs, weights, means, stds)
end

# Kernel behind `my_gradient!`: specialised on the concrete type of `grad!`.
function _accumulate_gradient!(grad!::F, out, tmp, xs, weights, means, stds) where {F}
    out .= 0
    for x in xs
        grad!(tmp, x, weights, means, stds)
        out .+= tmp
    end
    return nothing
end
# ========== Benchmark setup ==========
SEED = 42
N_SAMPLES = 500
N_COMPONENTS = 4
rnd = Random.MersenneTwister(SEED)
data = randn(rnd, N_SAMPLES)
# Parameter layout expected by mixture_loglikelihood: [weights; means; stds]
params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)]
# Persist the inputs; the Python/JAX script loads these two files
DelimitedFiles.writedlm("gen_data.csv", data, ',')
DelimitedFiles.writedlm("gen_params0.csv", params0, ',')
# `data` is a non-const global. Capturing it in a `let` gives the closure a concretely
# typed field instead of an untyped global lookup on every call, so the objective is
# inferrable and the benchmarks do not include that lookup.
objective = let data = data
    params -> mixture_loglikelihood(params, data)
end
@show params0
@show objective(params0)
@info "Settings" SEED N_SAMPLES N_COMPONENTS length(params0)
# ========== Actual benchmarks ==========
# Benchmark the Symbolics-generated gradient, summed over all samples by my_gradient!
@info "Computing gradient w/ Symbolics"
let
# `grad_storage` receives the accumulated gradient; `tmp` is the per-sample scratch buffer
grad_storage = similar(params0)
tmp = similar(params0)
# 1. Compile
my_gradient!(grad_storage, tmp, data, params0)
# 2. Benchmark
# Globals are interpolated with `$` into the benchmarked expression
trial = run(@benchmarkable $my_gradient!($grad_storage, $tmp, $data, $params0) samples=10_000 evals=1 seconds=60)
show(stdout, MIME("text/plain"), trial)
println()
# Print the gradient so it can be compared against the other packages' output
@show grad_storage
end
# Benchmark ForwardDiff's in-place gradient of `objective`
@info "Computing gradient w/ ForwardDiff"
let
grad_storage = similar(params0)
# Gradient configuration built once up front; chunk size is set to the number of parameters
cfg_grad = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{length(params0)}())
# 1. Compile
ForwardDiff.gradient!(grad_storage, objective, params0, cfg_grad)
# 2. Benchmark
trial = run(@benchmarkable ForwardDiff.gradient!($grad_storage, $objective, $params0, $cfg_grad) samples=10_000 evals=1 seconds=60)
show(stdout, MIME("text/plain"), trial)
println()
# Print the gradient so it can be compared against the other packages' output
@show grad_storage
end
# Benchmark ReverseDiff's in-place gradient using a pre-recorded, compiled tape
@info "Computing gradient w/ ReverseDiff"
let
grad_storage = similar(params0)
# The tape is recorded at `params0` and compiled once, outside the timed region
objective_tape = ReverseDiff.GradientTape(objective, params0) |> ReverseDiff.compile
# 1. Compile
ReverseDiff.gradient!(grad_storage, objective_tape, params0)
# 2. Benchmark
trial = run(@benchmarkable ReverseDiff.gradient!($grad_storage, $objective_tape, $params0) samples=10_000 evals=1 seconds=60)
show(stdout, MIME("text/plain"), trial)
println()
# Print the gradient so it can be compared against the other packages' output
@show grad_storage
end
# Benchmark Zygote's out-of-place gradient of `objective`
@info "Computing gradient w/ Zygote reverse"
let
# 1. Compile
# NOTE: unlike the other blocks, `grad_storage` here is whatever Zygote.gradient
# returns; it is not a preallocated buffer that gets filled in place
grad_storage = Zygote.gradient(objective, params0)
# 2. Benchmark
trial = run(@benchmarkable Zygote.gradient($objective, $params0) samples=10_000 evals=1 seconds=60)
show(stdout, MIME("text/plain"), trial)
println()
# Print the gradient so it can be compared against the other packages' output
@show grad_storage
end
import timeit
import numpy as np
import jax
import jax.numpy as jnp
# Enable float64 support
jax.config.update("jax_enable_x64", True)
@jax.jit
def normal_pdf(data, mean, var):
    """Density of the normal distribution N(mean, var) evaluated at ``data``.

    ``var`` is the variance, not the standard deviation. The expression is
    element-wise, so the arguments may be scalars or arrays that broadcast
    against each other.
    """
    return jnp.exp(-(data - mean)**2 / (2 * var)) / jnp.sqrt(2 * jnp.pi * var)
@jax.jit
def mixture_loglikelihood(params: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
    """Log-likelihood of ``data`` under a Gaussian mixture model.

    ``params`` is laid out as ``[weights, means, stds]`` concatenated, each part
    of length ``K = len(params) // 3``.

    ``data`` may be a 1-D array of N samples or an (N, 1) column. The result is a
    scalar JAX array.
    """
    K = len(params) // 3
    weights, means, stds = params[:K], params[K:2*K], params[2*K:]
    # Samples must lie along axis 0 so they broadcast against the K components on
    # axis 1. A 1-D `data` would otherwise broadcast element-wise against the
    # parameters (or fail), so it is promoted to a column first.
    samples = data[:, None] if data.ndim == 1 else data
    # `.T` is a no-op for 1-D parameter slices; it is kept so column-shaped
    # parameters behave as before.
    mat = normal_pdf(samples, means.T, stds.T**2)  # (N, K)
    return jnp.log((mat * weights.T).sum(1)).sum()
# Inputs written by the Julia script: samples as an (N, 1) column, parameters as a flat vector
data = np.loadtxt("gen_data.csv").reshape(-1, 1)
params0 = jnp.array(np.loadtxt("gen_params0.csv").flatten())


def objective(params):
    """Log-likelihood of the loaded data as a function of the parameters only."""
    return mixture_loglikelihood(params, data)


print(objective(params0))
# Output: -443.40397372007186
the_grad = jax.jit(jax.grad(objective))
print(the_grad(params0))
# Output: [289.73084956 199.27559525 236.68945778 292.06123402 -9.42979939
# 26.72229565 -1.91803555 37.9874909 -24.09562015 -13.93568733
# -38.00044666 12.87712892]
# ========== Benchmark ==========
N_SAMPLES, N_EVALS = 10_000, 1 # like in Julia
# JAX dispatches work asynchronously: a call can return before the result has been
# computed. Block on the result inside the timed statement so each sample measures
# the full gradient computation rather than just the dispatch.
bench_secs = timeit.repeat(
    "the_grad(params0).block_until_ready()",
    globals={'the_grad': the_grad, 'params0': params0},
    repeat=N_SAMPLES, number=N_EVALS
)
bench_mus = 1_000_000 * np.array(bench_secs)
print(f"{bench_mus.mean():.3f} μs ± {bench_mus.std():.3f} μs ({N_SAMPLES} samples with {N_EVALS} evaluation)")
# Recorded from a run made before block_until_ready() was added to the timed statement:
# 61.308 μs ± 5.368 μs (10000 samples with 1 evaluation)
$ julia-1.8 --project code.jl
Status `~/test/autodiff_bench/Project.toml`
[6e4b80f9] BenchmarkTools v1.3.1
[7da242da] Enzyme v0.10.4
[f6369f11] ForwardDiff v0.10.32
[37e2e3b7] ReverseDiff v1.14.1
[0c5d862f] Symbolics v4.10.3
[e88e6eb3] Zygote v0.6.43
[ Info: Loading packages...
[ Info: Generating gradient functions...
generate_gradient(fname, K) = 5517
generate_gradient(fname, K) = 9708
generate_gradient(fname, K) = 15090
generate_gradient(fname, K) = 21660
params0 = [0.25733304995705586, 0.4635056170085754, 0.5285451129509773, 0.7120981447127772, 0.835601264145011, -1.4646785862195637, 0.24736086263101278, -0.21967358320549735, 1.0624643704713206, 1.628664511492019, 1.8530572439128092, 0.6276756477143253]
objective(params0) = -443.4039737200718
┌ Info: Settings
│ SEED = 42
│ N_SAMPLES = 500
│ N_COMPONENTS = 4
└ length(params0) = 12
[ Info: Computing gradient w/ Symbolics
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 767.272 μs … 13.003 ms ┊ GC (min … max): 0.00% … 93.99%
Time (median): 776.226 μs ┊ GC (median): 0.00%
Time (mean ± σ): 809.821 μs ± 273.313 μs ┊ GC (mean ± σ): 1.11% ± 3.11%
█▇▅▂▂▃▅▄▄▂▃▂▂▂▁▁▁▁▁ ▄▂▄ ▂▁ ▂
█████████████████████████▇▇████▇▇▇███▇▇▇▇▇▇▅▆▆▅▅▆▅▅▅▅▄▅▅▅▃▆▄▅ █
767 μs Histogram: log(frequency) by time 1 ms <
Memory estimate: 78.12 KiB, allocs estimate: 2000.
grad_storage = [289.7308495620468, 199.2755952498574, 236.68945777568766, 292.0612340227955, -9.429799389881461, 26.722295646439054, -1.918035554675224, 37.987490895733956, -24.095620148778284, -13.935687326484123, -38.00044665702688, 12.877128915271317]
[ Info: Computing gradient w/ ForwardDiff
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 249.723 μs … 13.928 ms ┊ GC (min … max): 0.00% … 97.37%
Time (median): 450.822 μs ┊ GC (median): 0.00%
Time (mean ± σ): 515.443 μs ± 665.059 μs ┊ GC (mean ± σ): 10.82% ± 7.95%
▃▂ ▂▅██▇▆▅▄▅▅▅▅▄▃▂▁ ▁ ▂
██▇▆▇▆▆▇▇▄▅▄▅▁▃▁▅▅▅▆▄▁▁▆█████████████████████▇▇▇▇▇▇▇▆▇▇▇▇▇▇▆▆ █
250 μs Histogram: log(frequency) by time 688 μs <
Memory estimate: 508.25 KiB, allocs estimate: 13.
grad_storage = [289.7308495620468, 199.2755952498574, 236.68945777568769, 292.0612340227955, -9.429799389881461, 26.72229564643905, -1.9180355546752244, 37.98749089573394, -24.095620148778302, -13.935687326484116, -38.000446657026885, 12.8771289152713]
[ Info: Computing gradient w/ ReverseDiff
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 1.582 ms … 2.763 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.603 ms ┊ GC (median): 0.00%
Time (mean ± σ): 1.633 ms ± 74.832 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▄▇█▇▅▄▃▄▃▄▃▃▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
███████████████████████████████▇█▇▇▆▆▇▇▅▆▇▆▆▆▆▅▅▆▆▅▄▆▆▄▆▅▆ █
1.58 ms Histogram: log(frequency) by time 1.96 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
grad_storage = [289.7308495620467, 199.27559524985728, 236.6894577756876, 292.0612340227955, -9.429799389881452, 26.722295646439047, -1.9180355546752244, 37.98749089573397, -24.095620148778274, -13.935687326484114, -38.00044665702692, 12.877128915271324]
[ Info: Computing gradient w/ Zygote reverse
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 171.294 μs … 19.198 ms ┊ GC (min … max): 0.00% … 98.27%
Time (median): 216.675 μs ┊ GC (median): 0.00%
Time (mean ± σ): 271.229 μs ± 714.442 μs ┊ GC (mean ± σ): 13.96% ± 5.25%
▁▃▁ ▁▃▂▅█▅▅▄▂▃▂▂▃▁ ▁ ▂▄▄▄▄▄▃▃▂▁ ▂
█████▆▆▆▄▄▂▃█████████████████▇▇▆█▆▆█████████████▇▆▇▆▇▆▆▆▆▆▄▅▄ █
171 μs Histogram: log(frequency) by time 326 μs <
Memory estimate: 231.40 KiB, allocs estimate: 210.
grad_storage = ([289.7308495620465, 199.27559524985728, 236.68945777568746, 292.06123402279576, -9.429799389881458, 26.722295646439065, -1.9180355546752224, 37.98749089573398, -24.095620148778284, -13.935687326484112, -38.00044665702688, 12.877128915271337],)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment