Last active: March 13, 2019 03:50
Save GregVernon/2525ef489120f9f030ac8172b3e0c66f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base.Threads | |
import LinearAlgebra | |
import StaticArrays | |
import BenchmarkTools | |
"""
    NestedIteration_MatrixVector(M, V, R)

Compute every matrix-vector product `M[:, :, i] * V[:, j]` one pair at a time and store
it in `R[:, j, i]`.

# Arguments
- `M`: array whose third dimension indexes the matrices (`size(M, 3)` of them).
- `V`: matrix whose columns are the vectors to multiply (`size(V, 2)` of them).
- `R`: preallocated output, overwritten in place; slice `R[:, j, i]` receives the
  product of matrix `i` with vector `j`.

# Returns
`R`.

# Throws
- `BoundsError` if `R` has fewer than `size(V, 2)` columns or fewer than `size(M, 3)`
  slices along its third dimension.
"""
function NestedIteration_MatrixVector(M, V, R)
    nMatrix = size(M, 3)
    nVector = size(V, 2)
    # The assignment in the loop runs under `@inbounds`, which skips the check on R's
    # `j`/`i` indices; validate them once here so an undersized R cannot be written
    # out of bounds.
    checkbounds(R, :, 1:nVector, 1:nMatrix)
    for i in 1:nMatrix          # iterate through the matrices
        for j in 1:nVector      # iterate through the vectors
            @inbounds R[:, j, i] = M[:, :, i] * V[:, j]
        end
    end
    return R
end
"""
    Iteration_MatrixMatrix(M, V, R)

For each matrix `M[:, :, i]`, compute the matrix-matrix product `M[:, :, i] * V` and
store it in `R[:, :, i]`. Column `j` of that product equals `M[:, :, i] * V[:, j]`, so
this yields the same result as multiplying each matrix by each column of `V`
separately.

# Returns
`R`, modified in place.

# Throws
- `BoundsError` if `R` has fewer than `size(M, 3)` slices along its third dimension.
"""
function Iteration_MatrixMatrix(M, V, R)
    nMatrix = size(M, 3)
    # The assignment below runs under `@inbounds`, which skips the check on the slice
    # index `i`; verify up front that every target slice of R exists.
    checkbounds(R, :, :, 1:nMatrix)
    for i in 1:nMatrix  # iterate through the matrices
        @inbounds R[:, :, i] = M[:, :, i] * V
    end
    return R
end
"""
    Threaded_mul!(M, V, R)

For every matrix `M[i]` and vector `V[j]`, write `M[i] * V[j]` in place into the
existing container `R[j, i]` via `LinearAlgebra.mul!`. Iterations over the matrices
are distributed across threads with `Threads.@threads`.

Throws `BoundsError` unless `size(R) == (length(V), length(M))`. Returns `R`.
"""
function Threaded_mul!(M, V, R)
    nmat = length(M)
    nvec = length(V)
    if size(R) != (nvec, nmat)
        throw(BoundsError())
    end
    Threads.@threads for imat = 1:nmat
        for ivec = 1:nvec
            # R[ivec, imat] is itself an array; mul! overwrites its contents.
            @inbounds LinearAlgebra.mul!(R[ivec, imat], M[imat], V[ivec])
        end
    end
    return R
end
"""
    Threaded_StaticArrays(M, V, R)

Store the product `M[i] * V[j]` into `R[j, i]` for every matrix/vector pair, assigning
a freshly computed value rather than mutating an existing one (which is what immutable
element types such as static arrays require). Iterations over the matrices are
distributed across threads with `Threads.@threads`.

Throws `BoundsError` unless `size(R) == (length(V), length(M))`. Returns `R`.
"""
function Threaded_StaticArrays(M, V, R)
    nmat = length(M)
    nvec = length(V)
    size(R) == (nvec, nmat) || throw(BoundsError())
    Threads.@threads for col = 1:nmat
        @inbounds begin
            A = M[col]
            for row = 1:nvec
                R[row, col] = A * V[row]
            end
        end
    end
    return R
end
"""
    BLAS_GEMM(M, V, R)

Return the freshly allocated product `M * V` as a single matrix-matrix multiplication.

`R` is accepted so the signature matches the other benchmark functions, but it is
neither read nor written: the result is always a new array.
"""
function BLAS_GEMM(M, V, R)
    return M * V
end
"""
    Threaded_List_mul!(M, V, R)

For each matrix `M[i]`, compute `M[i] * V` in place into the existing array `R[i]`
via `LinearAlgebra.mul!`. Iterations over the matrices are distributed across threads
with `Threads.@threads`.

# Returns
`R`, whose first `length(M)` elements are overwritten.

# Throws
- `BoundsError` if `R` has fewer than `length(M)` elements.
"""
function Threaded_List_mul!(M, V, R)
    nmat = length(M)
    # R is a flat list with one output array per matrix. The loop indexes it under
    # `@inbounds`, so verify every target exists before any thread touches memory.
    checkbounds(R, 1:nmat)
    Threads.@threads for i = 1:nmat
        @inbounds LinearAlgebra.mul!(R[i], M[i], V)
    end
    return R
end
"""
    doBenchmarks(N1, N2, nMatrix, nVector)

Time every multiplication strategy in this file with `BenchmarkTools.@btime`, using
random `Float64` data: `nMatrix` matrices of size `N1 × N2`, each multiplied by
`nVector` vectors of length `N2`. Each strategy receives its inputs in the container
layout it expects (3-D arrays, vectors of arrays, static arrays, one stacked matrix).

Side effect: sets the BLAS thread count to `Threads.nthreads()` before the GEMM case
and does not restore it afterwards.

The value of the final `@btime` expression is returned.
"""
function doBenchmarks(N1, N2, nMatrix, nVector)
    ## My original functions: matrices stacked along dim 3, vectors as columns
    M = rand(N1, N2, nMatrix)
    V = rand(N2, nVector)
    R = zeros(N1, nVector, nMatrix)
    BenchmarkTools.@btime NestedIteration_MatrixVector($M, $V, $R)
    BenchmarkTools.@btime Iteration_MatrixMatrix($M, $V, $R)

    ## Threaded mul!: list of matrices, list of vectors, matrix of output vectors
    M = [rand(N1, N2) for i = 1:nMatrix]
    V = [rand(N2) for j = 1:nVector]
    R = [zeros(N1) for j = 1:nVector, i = 1:nMatrix]
    BenchmarkTools.@btime Threaded_mul!($M, $V, $R)

    ## Threaded StaticArrays
    M = rand(StaticArrays.SMatrix{N1,N2,Float64,N1 * N2}, nMatrix)
    V = rand(StaticArrays.SVector{N2,Float64}, nVector)
    # Plain Base `Matrix` whose elements are static vectors.
    R = Matrix{StaticArrays.SVector{N1,Float64}}(undef, nVector, nMatrix)
    BenchmarkTools.@btime Threaded_StaticArrays($M, $V, $R)

    ## BLAS GEMM: all matrices stacked vertically into one (N1*nMatrix) × N2 matrix
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    # `reduce(vcat, ...)` concatenates without splatting nMatrix arguments into one call.
    M = reduce(vcat, [rand(N1, N2) for i = 1:nMatrix])
    V = rand(N2, nVector)
    R = zeros(N1, nVector, nMatrix)  # only for signature parity; BLAS_GEMM never writes it
    BenchmarkTools.@btime BLAS_GEMM($M, $V, $R)

    ## Threaded mul! on a list of Arrays: one output matrix per input matrix
    M = [rand(N1, N2) for i = 1:nMatrix]
    V = rand(N2, nVector)
    R = [zeros(N1, nVector) for i = 1:nMatrix]
    BenchmarkTools.@btime Threaded_List_mul!($M, $V, $R)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.