Skip to content

Instantly share code, notes, and snippets.

@MilesCranmer
Last active April 6, 2023 01:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MilesCranmer/a8277006c6e6411ddfa9a28abdc4342b to your computer and use it in GitHub Desktop.
Save MilesCranmer/a8277006c6e6411ddfa9a28abdc4342b to your computer and use it in GitHub Desktop.
Compare forward-mode and reverse-mode differentiation as a function of the number of parameters
using BenchmarkTools
using ForwardDiff
using ReverseDiff
using Random
using Plots
using Statistics: mean, quantile, std
using Measurements
using Printf: @sprintf
using Colors
# Okabe & Ito palette, written as 8-bit (red, green, blue) triples; each channel is
# divided by 255 before being handed to `RGB`.
colors = [
    RGB(r / 255, g / 255, b / 255) for (r, g, b) in (
        (0, 114, 178),    # blue
        (230, 159, 0),    # orange
        (0, 158, 115),    # green
        (204, 121, 167),  # reddish purple
        (86, 180, 233),   # sky blue
        (213, 94, 0),     # vermillion
        (240, 228, 66),   # yellow
    )
]
"""
    _f(x, params, ::Val{n}) where {n}

Apply the element-wise map `x -> cos(x + params[i])` successively for `i = 1, …, n`
and return the final array. With `n == 0` the input `x` is returned unchanged.

The loop count is passed as a `Val` so that it is part of the method signature
(specialising on the number of loops, just in case).
"""
function _f(x, params, ::Val{n}) where {n}
    state = x
    for step in Base.OneTo(n)
        state = cos.(state .+ params[step])
    end
    return state
end
"""
    f(x, params)

Evaluate the nested-cosine chain on `x`, using one loop iteration per entry of
`params`. The length of `params` is wrapped in a `Val` and forwarded to `_f`.
"""
f(x, params) = _f(x, params, Val(length(params)))
# Root benchmark group, holding one child group per differentiation mode.
const suite = BenchmarkGroup()
for mode in ("forward", "reverse")
    suite[mode] = BenchmarkGroup()
end
"""
    forward(x, params)

Return the gradient of `p -> sum(f(x, p))` at `p = params`, computed with
`ForwardDiff.gradient`. `x` is held fixed.
"""
function forward(x, params)
    objective(p) = sum(f(x, p))
    return ForwardDiff.gradient(objective, params)
end
"""
    reverse(x, params)

Return the gradient of `p -> sum(f(x, p))` at `p = params`, computed with
`ReverseDiff.gradient`. `x` is held fixed.
"""
# NOTE(review): this name coincides with `Base.reverse`; presumably the definition
# creates a separate function in this module rather than extending Base's — confirm
# before using `Base.reverse` elsewhere in this script.
function reverse(x, params)
    objective(p) = sum(f(x, p))
    return ReverseDiff.gradient(objective, params)
end
# Warm-up: call each gradient wrapper once on one-element inputs before any
# benchmark is defined or run.
for gradient_fn in (forward, reverse)
    gradient_fn([1.0], [1.0])
end
# Grid of problem sizes: parameter counts are log-spaced from 10^0 to 10^3 in steps
# of a quarter decade (rounded to integers); input lengths are 1, 10, 100, 1000.
all_num_params = round.(Int, 10 .^ (0.0:0.25:3.0))
all_num_x = 10 .^ (0:3)

# Register one benchmark per (mode, num_params, num_x) combination.
for num_params in all_num_params
    # Each mode gets its own sub-group for this parameter count.
    for mode in ("forward", "reverse")
        suite[mode][num_params] = BenchmarkGroup()
    end

    # Evaluations per sample, interpolated linearly in log-space:
    # 100 evals at num_params = 1, down to 10 evals at num_params = 1000.
    num_evals = round(Int, 10^(2.0 - log10(num_params) / 3.0))

    for num_x in all_num_x
        # NOTE(review): the seed is set here, at definition time, whereas `setup`
        # presumably draws its random inputs when the benchmark is executed —
        # confirm this gives the reproducibility that was intended.
        Random.seed!(0)

        # Both modes use the same setup: fresh random `x` and `params` in [0, 6.28).
        for (mode, gradient_fn) in (("forward", forward), ("reverse", reverse))
            suite[mode][num_params][num_x] =
                @benchmarkable $gradient_fn(x, params) evals = num_evals samples = 1000 setup =
                    (x = rand($num_x) .* 6.28; params = rand($num_params) .* 6.28)
        end
    end
end
# Execute every benchmark registered in the suite.
res = run(suite; verbose = true)

# Collect the trials into a 3-D array indexed by
# (parameter count, input length, mode), with mode 1 = forward and 2 = reverse.
res_matrix = [
    res[mode][np][nx] for np in all_num_params, nx in all_num_x,
    mode in ("forward", "reverse")
]

# Summarise each trial as (mean ± standard deviation) of log10 of its sample times.
res_agg = map(res_matrix) do trial
    log_times = log10.(trial.times)
    return mean(log_times) ± std(log_times)
end

# Forward minus reverse in log10 space; exponentiating gives the time ratio.
ratio_forward_to_reverse = res_agg[:, :, 1] - res_agg[:, :, 2]
# Select the GR backend; figure size and dpi (300) are fixed here.
gr(; size = (600, 400), dpi = 300)

# Draw one line per input length nₓ. The x-axis is the number of parameters on a
# log scale; the y-axis holds values that are already differences of log10 times,
# so it stays linear. The values are measurements with uncertainties, so they are
# drawn with error bars; the marker stroke colour is set to the line colour so the
# error bars match the line.
for (ix, nx) in enumerate(all_num_x)
    series_color = colors[ix]
    # Column `ix` kept as an n×1 matrix.
    column = ratio_forward_to_reverse[:, ix:ix]
    series_kwargs = (
        label = "nₓ = $nx",
        xscale = :log10,
        color = series_color,
        markerstrokecolor = series_color,
    )
    if ix == 1
        # The first series starts a new figure ...
        plot(all_num_params, column; series_kwargs...)
    else
        # ... and the remaining ones are added to it.
        plot!(all_num_params, column; series_kwargs...)
    end
end
# Faint dashed reference line at y = 0 in log10 space, i.e. a forward/reverse time
# ratio of 1. The low alpha keeps it visually behind the data series.
plot!(
    all_num_params,
    fill(0.0, length(all_num_params));
    line = (:dash, 2.0),
    color = :black,
    alpha = 0.1,
    label = "",
)
# Text annotations as (x, y, caption). Below y = 0 the forward-mode time is the
# smaller one, above it the reverse-mode time is; the last entry explains the
# error bars.
for (xpos, ypos, caption) in (
    (3, -0.2, "Forward better"),
    (3, 0.2, "Reverse better"),
    (10^2.5, -1.3, "Errors show 1σ"),
)
    annotate!(xpos, ypos, text(caption, :black, 8))
end

# Legend position and axis labels.
plot!(;
    legend = :topleft,
    xlabel = "Number of parameters",
    ylabel = "Δt[ForwardDiff] / Δt[ReverseDiff]",
)

# x-ticks at 10^0, 10^1, 10^2, 10^3.
xticks!(10.0 .^ (0:3))

# y-ticks are placed at log10 of each ratio but labelled with the ratio itself.
new_ticks = [0.1, 1.0, 10.0]
yticks!((log10.(new_ticks), string.(new_ticks)))

# Write the figure to disk (dpi was set when the backend was selected).
savefig("forward_vs_reverse.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment