Skip to content

Instantly share code, notes, and snippets.

@staticfloat
Last active September 16, 2022 15:21
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save staticfloat/ce81a163807633748e414e2f3c628062 to your computer and use it in GitHub Desktop.
Save staticfloat/ce81a163807633748e414e2f3c628062 to your computer and use it in GitHub Desktop.
include("util.jl")

# First, do OpenBLAS64 vs. OpenBLAS32 testing (should be about the same).
# `clear=true` drops whatever was forwarded before; the second call then adds
# OpenBLAS32 on top of it, so both libraries are loaded at once.
BLAS.lbt_forward(OpenBLAS_jll.libopenblas_path; clear=true)
BLAS.lbt_forward(OpenBLAS32_jll.libopenblas_path)
# Not used below; presumably kept so the loaded libraries can be inspected
# from the REPL when the script is run with `julia -i`.
config = BLAS.get_config()
@show gemm_test(3000, 64)
@show gemm_test(3000, 32)

# Next, do OpenBLAS64 vs. Accelerate testing.
# Use the same problem size as above so the timings are directly comparable.
BLAS.lbt_forward(OpenBLAS_jll.libopenblas_path; clear=true)
BLAS.lbt_forward("/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate")
@show gemm_test(3000, 32)
$ julia -i blas.jl
gemm_test(3000, 64) = 0.333748 # <-- OpenBLAS
gemm_test(3000, 32) = 0.311170916 # <-- OpenBLAS32
gemm_test(3000, 32) = 0.299817792 # <-- Accelerate
using OpenBLAS_jll, OpenBLAS32_jll, libblastrampoline_jll, LinearAlgebra, Libdl
"""
    get_dgemm_addr(::Val{Int32})
    get_dgemm_addr(::Val{Int64})

Look up the `dgemm` entry point in the loaded libblastrampoline library and
return its address. `Val(Int32)` resolves the symbol `dgemm_`; `Val(Int64)`
resolves the symbol `dgemm_64_`.
"""
function get_dgemm_addr(::Val{Int32})
    lbt = libblastrampoline_jll.libblastrampoline_handle
    return dlsym(lbt, :dgemm_)
end

function get_dgemm_addr(::Val{Int64})
    lbt = libblastrampoline_jll.libblastrampoline_handle
    return dlsym(lbt, :dgemm_64_)
end
"""
    gemm!(transA, transB, alpha, A, B, beta, C, ::Val{BlasInt})

Call the `dgemm` routine whose address is returned by
`get_dgemm_addr(Val(BlasInt))`, passing every integer argument as `BlasInt`
(`Int32` or `Int64`), and return `C`, which the routine updates in place.
Per the BLAS `dgemm` contract this is `C = alpha*op(A)*op(B) + beta*C`.

# Arguments
- `transA`, `transB`: `'N'` uses the operand's dimensions as stored; any other
  character swaps its two dimensions when computing the problem size.
- `alpha`, `beta`: scalar factors.
- `A`, `B`, `C`: `Float64` vectors or matrices with unit stride in their first
  dimension.

# Throws
- `DimensionMismatch` if the inner dimensions of `A` and `B` differ, or if `C`
  does not have the size of the product.
- `ArgumentError` if `A`, `B` or `C` does not have unit stride in its first
  dimension (the routine is only given the second-dimension stride).
"""
function gemm!(transA::AbstractChar, transB::AbstractChar,
               alpha::Union{Float64,Bool},
               A::AbstractVecOrMat{Float64}, B::AbstractVecOrMat{Float64},
               beta::Union{Float64,Bool},
               C::AbstractVecOrMat{Float64},
               ::Val{BlasInt}) where {BlasInt<:Union{Int32,Int64}}
    m = size(A, transA == 'N' ? 1 : 2)
    ka = size(A, transA == 'N' ? 2 : 1)
    kb = size(B, transB == 'N' ? 1 : 2)
    n = size(B, transB == 'N' ? 2 : 1)
    if ka != kb || m != size(C, 1) || n != size(C, 2)
        throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
    end

    # Only the second-dimension stride is handed to the routine as the leading
    # dimension, so an operand whose first dimension is not contiguous would be
    # read (or written) at the wrong addresses. Reject it up front.
    for (label, X) in (("A", A), ("B", B), ("C", C))
        if stride(X, 1) != 1
            throw(ArgumentError(
                "$label must have unit stride in its first dimension, got $(stride(X, 1))"))
        end
    end

    fptr = get_dgemm_addr(Val(BlasInt))
    # The two trailing `Clong` arguments (both 1) follow the Fortran convention
    # of passing the lengths of the character arguments `transA` and `transB`.
    ccall(fptr, Cvoid,
          (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
           Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
           Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
           Ref{BlasInt}, Clong, Clong),
          transA, transB, m, n,
          ka, alpha, A, max(1, stride(A, 2)),
          B, max(1, stride(B, 2)), beta, C,
          max(1, stride(C, 2)), 1, 1)
    return C
end
"""
    gemm(transA, transB, alpha, A, B, ::Val{BlasInt})

Allocate an output matrix of the product's size and fill it by calling `gemm!`
with `beta` set to zero, using the `BlasInt` integer interface. Return that
matrix.

The output is created with `similar(B, Float64, dims)`, where the row count
comes from `A` and the column count from `B`; for each operand, `'N'` selects
its dimensions as stored and any other character selects them swapped.
"""
function gemm(transA::AbstractChar, transB::AbstractChar, alpha::Float64,
              A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64},
              ::Val{BlasInt}) where {BlasInt<:Union{Int32,Int64}}
    nrows = transA == 'N' ? size(A, 1) : size(A, 2)
    ncols = transB == 'N' ? size(B, 2) : size(B, 1)
    out = similar(B, Float64, (nrows, ncols))
    return gemm!(transA, transB, alpha, A, B, zero(Float64), out, Val(BlasInt))
end
"""
    gemm_test(n::Int, interface::Int)

Time `gemm` on an `n`-by-`n` matrix filled with `0.1`, multiplied by itself,
and return the smallest of ten `@elapsed` measurements (in seconds).

`interface` selects the integer width used for the call: `32` uses `Int32`
and `64` uses `Int64`. A 200-by-200 warm-up call is made first so that the
timed calls do not include first-call overhead.

# Throws
- `ArgumentError` if `interface` is neither `32` nor `64`.
"""
function gemm_test(n::Int, interface::Int)
    # Reject anything but the two supported widths; otherwise a typo would
    # silently time the 32-bit interface under the wrong label.
    if !(interface in (32, 64))
        throw(ArgumentError("interface must be 32 or 64, got $interface"))
    end
    blasint = Val(interface == 64 ? Int64 : Int32)

    # First, warmup
    W = fill(0.1, 200, 200)
    gemm('N', 'N', 1.0, W, W, blasint)

    # Next, actually call it; do so ten times, then return the minimum
    X = fill(0.1, n, n)
    ts = [@elapsed(gemm('N', 'N', 1.0, X, X, blasint)) for _ in 1:10]
    return minimum(ts)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment