@maxbennedich
Last active August 2, 2018 15:54
Matrix-vector multiplication benchmarks (Julia 0.7 version)
# Matrix-vector multiplication benchmarks (Julia 0.7 version)
using LinearAlgebra
using Statistics
using Printf
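
# Note (not in the original gist): mul_parallel_loop! below uses Threads.@threads, so
# start Julia with the JULIA_NUM_THREADS environment variable set (e.g. JULIA_NUM_THREADS=4)
# for the parallel test to use more than one thread; Threads.nthreads() reports the count.
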
f(a, b) = a * b # Matrix-vector multiplication
#f(a, b) = log(1 + a^b) # Slow operation

function mul_broadcast!(M, v)
    # For f = a*b, use M .*= v' (same performance)
    M .= f.(M, v')
end

function mul_diagonal!(M, v)
    rmul!(M, Diagonal(v))
end

function mul_loop!(M, v)
    n, m = size(M)
    for c = 1:m
        x = v[c]
        for r = 1:n
            # For f = a*b, use M[r,c] *= x (same performance)
            M[r,c] = f(M[r,c], x)
        end
    end
end

function mul_loop_inbounds!(M, v)
    n, m = size(M)
    @inbounds for c = 1:m
        x = v[c]
        for r = 1:n
            M[r,c] = f(M[r,c], x)
        end
    end
end

function mul_parallel_loop!(M, v)
    n, m = size(M)
    @inbounds Threads.@threads for c = 1:m
        x = v[c]
        for r = 1:n
            M[r,c] = f(M[r,c], x)
        end
    end
end

# Run expression 'ex' 'iterations' times, drop the first run (to exclude compilation),
# and report the median time and memory usage
macro timem(name, iterations, ex)
    quote
        t = [collect((@timed $(esc(ex)))[2:3]) for x=1:$iterations][2:end]
        m = median(reduce(hcat, t), dims=2)
        @printf("%s%5.0f ms (%d bytes)\n", $(esc(name)), m[1]*1e3, m[2])
    end
end
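# Example use (not in the original gist): the first argument is printed verbatim as the
# label, so pad it for alignment, e.g.
#   @timem rpad("example", 18) 10 sum(rand(10^6))
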
m = 1000
n = 100000
@printf("Matrix size: %d x %d\n", m, n)
types = [Int16, Int64, Float16, Float32, Float64]

tests = [
    (name = "Parallel for loop", fun = mul_parallel_loop!)
    (name = "For loop", fun = mul_loop!)
    (name = "For loop @inbounds", fun = mul_loop_inbounds!)
    (name = "Broadcast", fun = mul_broadcast!)
    (name = "Diagonal", fun = mul_diagonal!)]
width = maximum(map(t -> length(t.name), tests))

for dt in types
    println("\n$dt:")
    M = rand(dt, n, m)
    v = rand(dt, m)
    norms = map(t -> (A = copy(M); @timem rpad(t.name, width) 10 t.fun(A, v); norm(A)), tests)
    if !all(diff(norms) .== 0) println("==> Norms differ: $norms") end
end
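
# Sketch (not part of the original gist): a quick sanity check of one in-place function
# against the non-mutating reference M * Diagonal(v). Only valid for the default
# f(a, b) = a * b; the names Mchk/vchk are illustrative.
let
    Mchk = rand(Float64, 100, 10)
    vchk = rand(Float64, 10)
    ref = Mchk * Diagonal(vchk)   # scales column c of Mchk by vchk[c]
    A = copy(Mchk)
    mul_broadcast!(A, vchk)
    @assert A ≈ ref
end
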
Sample output:

Matrix size: 1000 x 100000

Int16:
Parallel for loop    19 ms (48 bytes)
For loop             62 ms (0 bytes)
For loop @inbounds   16 ms (0 bytes)
Broadcast            63 ms (0 bytes)
Diagonal             48 ms (16 bytes)

Int64:
Parallel for loop    62 ms (48 bytes)
For loop             80 ms (0 bytes)
For loop @inbounds   73 ms (0 bytes)
Broadcast            86 ms (0 bytes)
Diagonal             90 ms (16 bytes)

Float16:
Parallel for loop   565 ms (48 bytes)
For loop           2691 ms (0 bytes)
For loop @inbounds 2668 ms (0 bytes)
Broadcast          2647 ms (0 bytes)
Diagonal           2648 ms (16 bytes)

Float32:
Parallel for loop    31 ms (48 bytes)
For loop             65 ms (0 bytes)
For loop @inbounds   34 ms (0 bytes)
Broadcast            66 ms (0 bytes)
Diagonal             55 ms (16 bytes)

Float64:
Parallel for loop    62 ms (48 bytes)
For loop             79 ms (0 bytes)
For loop @inbounds   68 ms (0 bytes)
Broadcast            88 ms (0 bytes)
Diagonal             89 ms (16 bytes)
