Skip to content

Instantly share code, notes, and snippets.

@nlw0
Created April 17, 2022 10:11
Show Gist options
  • Save nlw0/80eca5f6b873e9b75c760b8635c0a593 to your computer and use it in GitHub Desktop.
Save nlw0/80eca5f6b873e9b75c760b8635c0a593 to your computer and use it in GitHub Desktop.
Compare the speed and accuracy of Julia's pairwise summation (the algorithm behind `sum`) with a plain `@simd` for-loop over Float32 values
using BenchmarkTools
using Statistics
using GLMakie
"""
    simpsum(a)

Sum the elements of `a` with a single running accumulator in a `@simd` loop
(no pairwise blocking). The accumulator starts as `zero(eltype(a))`, so
`Float32` input is accumulated in `Float32`. `@simd` lets the compiler
reassociate the additions, so the result may differ slightly from a strict
left-to-right sum. Returns `zero(eltype(a))` for an empty `a`.
"""
function simpsum(a)
    acc = zero(eltype(a))
    # Iterating `eachindex(a)` (not `1:length(a)`) keeps `@inbounds` safe for
    # arrays whose indices do not start at 1.
    @inbounds @simd for i in eachindex(a)
        acc += a[i]
    end
    return acc
end
"""
    pwsum(a, blksize)

Sum `a` with Base's pairwise (cascade) reduction kernel `Base.mapreduce_impl`.
The kernel recursively halves the index range and adds blocks of roughly
`blksize` elements sequentially. Returns `zero(eltype(a))` for an empty `a`,
matching `simpsum`.

NOTE: `Base.mapreduce_impl` is internal, non-public API; its signature may
change between Julia versions.
"""
function pwsum(a, blksize)
    # mapreduce_impl does not handle an empty index range itself (it reads the
    # first two elements under @inbounds), so guard it here.
    isempty(a) && return zero(eltype(a))
    # Pass the array's real index bounds, and convert blksize because the
    # kernel's block-size parameter is declared `::Int`.
    return Base.mapreduce_impl(identity, +, a, firstindex(a), lastindex(a), Int(blksize))
end
"""
    timetest(Npt, blksize)

Benchmark `simpsum` and `pwsum(⋅, blksize)` on one random `Float32` vector of
length `Npt`. Returns a 2-element vector: the median of the BenchmarkTools
trial times for `simpsum`, then for `pwsum`.
"""
function timetest(Npt, blksize)
    sample = rand(Float32, Npt)
    # `$`-interpolation keeps global-variable access out of the measurement.
    simp_trial = @benchmark simpsum($sample)
    pw_trial = @benchmark pwsum($sample, $blksize)
    return [median(trial.times) for trial in (simp_trial, pw_trial)]
end
"""
    acctest(Npt, blksize, Niter=333)

Estimate the accuracy of `simpsum` and `pwsum` over `Niter` trials. Each trial
draws `randn(Float32, Npt)`, uses the sum of its `Float64` conversion as the
reference, and computes the relative error `(reference - result) / reference`
of both `Float32` sums.

Returns a 2×1 matrix holding the median absolute relative error of `simpsum`
(row 1) and of `pwsum` (row 2).
"""
function acctest(Npt, blksize, Niter=333)
    # One trial -> [simpsum relative error, pwsum relative error].
    function trial_errors()
        sample = randn(Float32, Npt)
        reference = sum(Float64.(sample))
        simp_rel = (reference - simpsum(sample)) / reference
        pw_rel = (reference - pwsum(sample, blksize)) / reference
        return [simp_rel, pw_rel]
    end

    per_trial = [trial_errors() for _ in 1:Niter]
    # hcat gives a 2×Niter matrix; take the median across trials (dims=2).
    return median(abs.(reduce(hcat, per_trial)); dims=2)
end
# Problem sizes are 2^3 … 2^23; the same exponents label every panel's x axis.
blksize = 1024
exponents = 3:1:23
ll = 2 .^ exponents
xtick_spec = (1:length(ll), ["2^$(n)" for n in exponents])

accu = reduce(hcat, acctest(Npt, blksize) for Npt in ll)    # 2 × length(ll)
times = reduce(hcat, timetest(Npt, blksize) for Npt in ll)  # 2 × length(ll)

f = Figure()

ax = Axis(
    f[1, 1],
    title="Median of absolute relative error",
    xticks=xtick_spec,
    yscale=log10,
)
scatterlines!(ax, accu[1, :], label="simp err")
scatterlines!(ax, accu[2, :], label="pwsum err")
axislegend(ax, position=:lt)

ax2 = Axis(
    f[2, 1],
    title="Running time",
    ylabel="median time (ns)",  # BenchmarkTools trial times are in nanoseconds
    xticks=xtick_spec,
    yscale=log10,
)
scatterlines!(ax2, times[1, :], label="simp time")
scatterlines!(ax2, times[2, :], label="pwsum time")
axislegend(ax2, position=:lt)

ax3 = Axis(
    f[3, 1],
    title="Relative running time",
    xticks=xtick_spec,
)
barplot!(ax3, times[2, :] ./ times[1, :], label="pwsum time / simp time")
# Makie's Axis has no `ylim` attribute; set the y range explicitly.
ylims!(ax3, 0, 5)
axislegend(ax3, position=:lt)

# End with the figure so that `include`-ing this script in the REPL displays it.
f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment