Skip to content

Instantly share code, notes, and snippets.

@staticfloat
Created May 29, 2019 20:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save staticfloat/4905a9049e75f2ec7c7d10c4789a2239 to your computer and use it in GitHub Desktop.
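Threading speed test: a multithreaded column-wise `softmax!` kernel and a BenchmarkTools script that times it for several matrix sizes.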
using Base.Threads
"""
    softmax!(out, xs)

Write the column-wise softmax of `xs` into `out` and return `out`.

Each column `xs[:, j]` is shifted by its maximum before exponentiation (for numerical
stability) and then divided by its sum, so every column of `out` sums to one. For a
vector, the whole vector is treated as a single column.

Columns are processed with `Threads.@threads`. Each iteration reads and writes only
column `j`, so no synchronization is needed. `out` may alias `xs`: `softmax!(x, x)`
computes the softmax in place.

# Throws
- `DimensionMismatch` if `out` and `xs` do not have the same row and column axes.
"""
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
    rows, cols = axes(xs, 1), axes(xs, 2)
    # Validate up front: the loops below run under `@inbounds`, so a shape mismatch
    # would otherwise read/write out of bounds instead of raising an error.
    if axes(out, 1) != rows || axes(out, 2) != cols
        throw(DimensionMismatch("out has axes $(axes(out)) but xs has axes $(axes(xs))"))
    end
    # Columns with no rows have no maximum; the (empty) result needs no work.
    isempty(rows) && return out
    lastrow = last(rows)
    # Remove `@threads` for non-threading timing
    @threads for j in cols
        # `@inbounds` must sit inside the closure `@threads` generates to take effect.
        @inbounds begin
            # Column-wise maximum. It is kept in a local rather than in `out` so that
            # an aliased call (`out === xs`) does not overwrite an input before use.
            m = xs[lastrow, j]
            for i in first(rows):(lastrow - 1)
                m = max(m, xs[i, j])
            end
            # Subtract the column maximum to avoid overflow in exp().
            for i in rows
                out[i, j] = exp(xs[i, j] - m)
            end
            # Normalize by the column sum.
            s = zero(T)
            for i in rows
                s += out[i, j]
            end
            for i in rows
                out[i, j] /= s
            end
        end
    end
    return out
end
using BenchmarkTools, JLD2, Base.Threads
include("softmax.jl")
# Benchmark `softmax!` over a grid of column lengths `N` and batch sizes (number of
# columns), and save the resulting trials to "<nthreads()>_threads.jld2" so that runs
# with different thread counts can be compared afterwards.
# Keys are `(N, batch_size)`; values are the `BenchmarkTools.Trial` from `@benchmark`.
times = Dict{Tuple{Int,Int},BenchmarkTools.Trial}()
for N in (10, 100, 1000)
    for batch_size in 2 .^ (1:8)
        @show (N, batch_size)
        x = randn(Float32, N, batch_size)
        # `softmax!` overwrites every element of the output, so `undef` is safe here.
        y = Array{Float32}(undef, N, batch_size)
        # `$`-interpolate the inputs so BenchmarkTools times the call itself,
        # not the lookup of non-constant variables.
        times[(N, batch_size)] = @benchmark softmax!($y, $x)
    end
end
JLD2.@save "$(nthreads())_threads.jld2" times
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment