Skip to content

Instantly share code, notes, and snippets.

@antoine-levitt
Created August 21, 2019 07:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antoine-levitt/565750d1e7a323330e20e7b58e55c895 to your computer and use it in GitHub Desktop.
Save antoine-levitt/565750d1e7a323330e20e7b58e55c895 to your computer and use it in GitHub Desktop.
import LinearAlgebra.BLAS
const libblas = Base.libblas_name
const liblapack = Base.liblapack_name
import LinearAlgebra
import LinearAlgebra: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, stride1, chkstride1, axpy!
import Libdl
"""
    gemm!(transA, transB, alpha, A, B, beta, C)

Overwrite `C` with `alpha * op(A) * op(B) + beta * C` by calling the BLAS `?gemm`
routine directly through `ccall`, and return `C`. Following the BLAS convention,
`op(X)` is `X` for `'N'`, `transpose(X)` for `'T'` and `adjoint(X)` for `'C'`.

Methods are generated for `Float64`/`Float32` (calling `dgemm_`/`sgemm_`) and for
`ComplexF64`/`ComplexF32` (calling `zgemm3m_`/`cgemm3m_`). The complex variants are
not part of reference BLAS; they must be exported by the linked BLAS library
(e.g. OpenBLAS), otherwise the `ccall` symbol lookup fails when first executed.

# Throws
- `ArgumentError` if `transA` or `transB` is not one of `'N'`, `'T'`, `'C'`.
- `DimensionMismatch` if the shapes of `op(A)`, `op(B)` and `C` are incompatible.
- Whatever `chkstride1` throws if an operand does not have unit first-dimension stride.
"""
function gemm! end

"""
    gemm(transA, transB, [alpha,] A, B)

Return `alpha * op(A) * op(B)` as a newly allocated matrix (`alpha` defaults to one),
computed with `gemm!`.
"""
function gemm end

for (blasfn, elty) in ((:dgemm_, :Float64),
                       (:sgemm_, :Float32),
                       (:zgemm3m_, :ComplexF64),
                       (:cgemm3m_, :ComplexF32))
    @eval begin
        # Fortran interface being called:
        #   SUBROUTINE xGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
        #   scalars: ALPHA, BETA; INTEGER K,LDA,LDB,LDC,M,N; CHARACTER TRANSA,TRANSB
        #   arrays:  A(LDA,*), B(LDB,*), C(LDC,*)
        function gemm!(transA::Char, transB::Char, alpha::($elty),
                       A::AbstractVecOrMat{$elty}, B::AbstractVecOrMat{$elty},
                       beta::($elty), C::AbstractVecOrMat{$elty})
            # The size logic below only recognizes 'N' as "no transpose", whereas BLAS
            # also accepts lowercase flags. Reject anything else so the dimensions passed
            # to BLAS always agree with how BLAS interprets the flags.
            transA in ('N', 'T', 'C') ||
                throw(ArgumentError("transA must be 'N', 'T' or 'C', got '$transA'"))
            transB in ('N', 'T', 'C') ||
                throw(ArgumentError("transB must be 'N', 'T' or 'C', got '$transB'"))
            # op(A) is m×ka and op(B) is kb×n; the inner dimensions must agree.
            m = size(A, transA == 'N' ? 1 : 2)
            ka = size(A, transA == 'N' ? 2 : 1)
            kb = size(B, transB == 'N' ? 1 : 2)
            n = size(B, transB == 'N' ? 2 : 1)
            if ka != kb || m != size(C, 1) || n != size(C, 2)
                throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
            end
            # BLAS addresses each operand as a column-major array with unit stride
            # along the first dimension.
            chkstride1(A)
            chkstride1(B)
            chkstride1(C)
            # The leading dimensions are the column strides, clamped to >= 1 because
            # BLAS rejects LDA/LDB/LDC < 1 (which a stride can be for empty arrays).
            ccall((BLAS.@blasfunc($blasfn), libblas), Cvoid,
                  (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
                   Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
                   Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
                   Ref{BlasInt}),
                  transA, transB, m, n,
                  ka, alpha, A, max(1, stride(A, 2)),
                  B, max(1, stride(B, 2)), beta, C,
                  max(1, stride(C, 2)))
            return C
        end

        function gemm(transA::Char, transB::Char, alpha::($elty),
                      A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
            # The result has as many rows as op(A) and as many columns as op(B).
            C = similar(B, $elty, (size(A, transA == 'N' ? 1 : 2),
                                   size(B, transB == 'N' ? 2 : 1)))
            # beta = 0, so the uninitialized contents of C are never used.
            return gemm!(transA, transB, alpha, A, B, zero($elty), C)
        end

        function gemm(transA::Char, transB::Char,
                      A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
            return gemm(transA, transB, one($elty), A, B)
        end
    end
end
# Benchmark the `gemm!` defined above (which calls zgemm3m_ for ComplexF64) against
# the standard `BLAS.gemm!` for square complex matrices of size 2, 4, ..., 1024.
# NOTE(review): `@btime` is provided by BenchmarkTools, which this file never loads;
# run `using BenchmarkTools` before this loop or it fails with an undefined macro.
for N in 2 .^ (1:10)
    A = randn(N, N) + im * randn(N, N)
    B = randn(N, N) + im * randn(N, N)
    C = randn(N, N) + im * randn(N, N)
    println(N)
    # Check that both kernels compute the same product before timing them; a fast
    # but wrong kernel would make the comparison meaningless. A tolerance is used
    # because the 3m routine may round differently from the classical product.
    C3m = gemm!('N', 'N', one(ComplexF64), A, B, zero(ComplexF64), zero(C))
    Cref = BLAS.gemm!('N', 'N', one(ComplexF64), A, B, zero(ComplexF64), zero(C))
    C3m ≈ Cref || error("gemm! (3m) and BLAS.gemm! disagree for N = $N")
    @btime gemm!('N', 'N', one(ComplexF64), $A, $B, zero(ComplexF64), $C)
    @btime BLAS.gemm!('N', 'N', one(ComplexF64), $A, $B, zero(ComplexF64), $C)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment