# Source: GitHub gist by @GregVernon (GregVernon/2525ef489120f9f030ac8172b3e0c66f),
# last active March 13, 2019 03:50.
import Base.Threads
import LinearAlgebra
import StaticArrays
import BenchmarkTools
"""
    NestedIteration_MatrixVector(M, V, R)

Multiply every matrix slice `M[:, :, i]` by every column `V[:, j]`, write each
product into `R[:, j, i]`, and return `R`.

# Arguments
- `M`: 3-D array; matrix `i` is the slice `M[:, :, i]`.
- `V`: 2-D array; vector `j` is the column `V[:, j]`.
- `R`: 3-D output array, overwritten in place. It needs `size(M, 1)` rows and at
  least `size(V, 2)` columns and `size(M, 3)` slices.

# Throws
- `BoundsError` if `R` has too few columns or slices.
- `DimensionMismatch` if the slice shapes are not compatible for multiplication.
"""
function NestedIteration_MatrixVector(M,V,R)
    nMatrix = size(M, 3)
    nVector = size(V, 2)
    # `@views` turns every slice below into a non-copying view. Creating a view is
    # bounds-checked, so an undersized `R` raises instead of being written past its end.
    @views for i in 1:nMatrix # Iterate through the matrices
        for j in 1:nVector # Iterate through the vectors
            # In-place product: no temporary result vector is allocated.
            LinearAlgebra.mul!(R[:, j, i], M[:, :, i], V[:, j])
        end
    end
    return R
end
"""
    Iteration_MatrixMatrix(M, V, R)

For every matrix slice `M[:, :, i]`, compute the matrix-matrix product with `V`
and store it in `R[:, :, i]`; return `R`.

Column `j` of each product equals `M[:, :, i] * V[:, j]`, so the result matches
multiplying each matrix by each column of `V` separately.

# Arguments
- `M`: 3-D array; matrix `i` is the slice `M[:, :, i]`.
- `V`: 2-D array whose columns are the vectors to multiply.
- `R`: 3-D output array, overwritten in place; slice `i` must have size
  `(size(M, 1), size(V, 2))`.

# Throws
- `BoundsError` if `R` has fewer slices than `M`.
- `DimensionMismatch` if the shapes are not compatible for multiplication.
"""
function Iteration_MatrixMatrix(M,V,R)
    # Views avoid copying each slice of `M`, and `mul!` writes straight into `R`.
    # View creation is bounds-checked, so an undersized `R` raises an error.
    @views for i in 1:size(M, 3) # Iterate through the matrices
        LinearAlgebra.mul!(R[:, :, i], M[:, :, i], V)
    end
    return R
end
"""
    Threaded_mul!(M, V, R)

Compute `M[i] * V[j]` in place into `R[j, i]` for every pair `(i, j)` using
`LinearAlgebra.mul!`, with the loop over `M` run under `Threads.@threads`.
Returns `R`.

# Arguments
- `M`: indexable collection of left-hand operands.
- `V`: indexable collection of right-hand operands.
- `R`: matrix of preallocated outputs; `size(R)` must equal
  `(length(V), length(M))`.

# Throws
- `BoundsError` if `R` does not have the required size.
"""
function Threaded_mul!(M,V,R)
    nmat = length(M)
    nvec = length(V)
    if size(R) != (nvec, nmat)
        throw(BoundsError())
    end
    Threads.@threads for m in 1:nmat
        A = @inbounds M[m]
        for v in 1:nvec
            # Each (v, m) pair writes to its own output buffer.
            @inbounds LinearAlgebra.mul!(R[v, m], A, V[v])
        end
    end
    return R
end
"""
    Threaded_StaticArrays(M, V, R)

Store the product `M[i] * V[j]` in `R[j, i]` for every pair `(i, j)`, with the
loop over `M` run under `Threads.@threads`. Returns `R`.

Unlike an in-place `mul!`, each product is a new value assigned into `R`, so the
elements of `R` do not need to be initialised beforehand.

# Arguments
- `M`: indexable collection of left-hand operands.
- `V`: indexable collection of right-hand operands.
- `R`: output matrix; `size(R)` must equal `(length(V), length(M))`.

# Throws
- `BoundsError` if `R` does not have the required size.
"""
function Threaded_StaticArrays(M,V,R)
    required = (length(V), length(M))
    size(R) == required || throw(BoundsError())
    Threads.@threads for col in 1:length(M)
        @inbounds for row in 1:length(V)
            R[row, col] = M[col] * V[row]
        end
    end
    return R
end
"""
    BLAS_GEMM(M, V, R)

Return the product `M * V`.

`R` is accepted so the signature matches the other kernels in this file, but it
is neither read nor modified; the product is returned as a separate value.
"""
function BLAS_GEMM(M,V,R)
    return M * V
end
"""
    Threaded_List_mul!(M, V, R)

Compute `M[i] * V` in place into `R[i]` for every `i` using
`LinearAlgebra.mul!`, with the loop run under `Threads.@threads`. Returns `R`.

# Arguments
- `M`: vector of left-hand matrices.
- `V`: right-hand operand shared by all products.
- `R`: vector of preallocated outputs; it must hold at least `length(M)` entries.

# Throws
- `BoundsError` if `R` has fewer entries than `M`.
"""
function Threaded_List_mul!(M,V,R)
    # Validate before the loop: the `@inbounds` access to `R[i]` below would
    # otherwise read past the end of a too-short `R`.
    length(R) >= length(M) || throw(BoundsError(R, length(M)))
    Threads.@threads for i in 1:length(M)
        @inbounds LinearAlgebra.mul!(R[i], M[i], V)
    end
    return R
end
"""
    doBenchmarks(N1, N2, nMatrix, nVector)

Build random inputs and run `BenchmarkTools.@btime` on each multiplication kernel
defined in this file, in turn.

# Arguments
- `N1`, `N2`: number of rows and columns of each matrix.
- `nMatrix`: number of matrices.
- `nVector`: number of vectors, each of length `N2`.

The BLAS thread count is set to `Threads.nthreads()` before the GEMM benchmark
and is not restored afterwards, so it also applies to the final benchmark.
"""
function doBenchmarks(N1,N2,nMatrix,nVector)
    ## My original functions
    M = rand(N1, N2, nMatrix)
    V = rand(N2, nVector)
    R = zeros(N1, nVector, nMatrix)
    BenchmarkTools.@btime NestedIteration_MatrixVector($M,$V,$R)
    BenchmarkTools.@btime Iteration_MatrixMatrix($M,$V,$R)
    ## Threaded mul!
    M = [rand(N1, N2) for i in 1:nMatrix]
    V = [rand(N2) for j in 1:nVector]
    R = [zeros(N1) for j in 1:nVector, i in 1:nMatrix]
    BenchmarkTools.@btime Threaded_mul!($M,$V,$R)
    ## Threaded StaticArrays
    M = rand(StaticArrays.SMatrix{N1,N2,Float64,N1*N2}, nMatrix)
    V = rand(StaticArrays.SVector{N2,Float64}, nVector)
    # `Matrix` is Base's array type; it does not need to be reached through StaticArrays.
    R = Matrix{StaticArrays.SVector{N1,Float64}}(undef, nVector, nMatrix)
    BenchmarkTools.@btime Threaded_StaticArrays($M,$V,$R)
    ## BLAS GEMM
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    # All matrices stacked vertically into one tall (N1*nMatrix) x N2 random matrix,
    # generated directly instead of splatting a list of blocks into `vcat`.
    M = rand(N1 * nMatrix, N2)
    V = rand(N2, nVector)
    R = zeros(N1, nVector, nMatrix) # passed for signature parity; BLAS_GEMM does not write to it
    BenchmarkTools.@btime BLAS_GEMM($M,$V,$R)
    ## Threaded mul! on a list of Arrays
    M = [rand(N1, N2) for i in 1:nMatrix]
    V = rand(N2, nVector)
    R = [zeros(N1, nVector) for i in 1:nMatrix]
    BenchmarkTools.@btime Threaded_List_mul!($M,$V,$R)
end
# (GitHub page footer, not part of the script: "Sign up for free to join this conversation on GitHub.")