Skip to content

Instantly share code, notes, and snippets.

@jebej
Last active June 3, 2017 12:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jebej/a5932b0df4cd11dc39218531d2fb3e34 to your computer and use it in GitHub Desktop.
Save jebej/a5932b0df4cd11dc39218531d2fb3e34 to your computer and use it in GitHub Desktop.
Symmetric Sparse Implementation Benchmarks
module SymmetricSparseTests

# Prototype of a `Symmetric`-aware sparse matrix product plus a benchmark (`runtests`)
# comparing it with the plain `SparseMatrixCSC` product.
# Written against the Julia 0.6 API (`A_mul_B!`, `scale!`, `full`, `Base.LinAlg`).

using BenchmarkTools
import Base: Symmetric, *, A_mul_B!, LinAlg.checksquare

"""
    Symmetric(A::SparseMatrixCSC, uplo::Symbol=:U)

Wrap the square sparse matrix `A` in a `Symmetric` whose triangle is selected by `uplo`
(`:U` or `:L`). `A` is stored as-is (not copied). Throws `DimensionMismatch` if `A` is
not square.
"""
function Symmetric(A::SparseMatrixCSC, uplo::Symbol=:U)
    checksquare(A)
    return Symmetric{eltype(A),typeof(A)}(A, Base.LinAlg.char_uplo(uplo)) # preserve A
end

"""
    *(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, x::StridedVecOrMat)

Return a newly allocated `A * x`, computed by `A_mul_B`.
"""
(*)(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, x::StridedVecOrMat{Tx}) where {TA,S,Tx} =
    A_mul_B(A, x)

# Throw `DimensionMismatch` unless `C = A * B` is well-formed. The kernels index `B` and
# `C` under `@inbounds`, so all three dimensions they use must be validated here.
function _check_mul_dims(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, B::StridedVecOrMat,
                         C::StridedVecOrMat) where {TA,S}
    A.data.n == size(B, 1) || throw(DimensionMismatch(
        "A has $(A.data.n) columns but B has $(size(B, 1)) rows"))
    A.data.m == size(C, 1) || throw(DimensionMismatch(
        "A has $(A.data.m) rows but C has $(size(C, 1)) rows"))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "B has $(size(B, 2)) columns but C has $(size(C, 2)) columns"))
    return nothing
end

"""
    A_mul_B!(α, A::Symmetric{TA,SparseMatrixCSC{TA,S}}, B, β, C)

Overwrite `C` with `α * A * B + β * C` and return `C`. Only the triangle of `A.data`
selected by `A.uplo` is read; row indices within each column of `A.data` are assumed
sorted. Throws `DimensionMismatch` if the sizes of `A`, `B` and `C` are incompatible.
"""
function A_mul_B!(α::Number, A::Symmetric{TA,SparseMatrixCSC{TA,S}}, B::StridedVecOrMat,
                  β::Number, C::StridedVecOrMat) where {TA,S}
    _check_mul_dims(A, B, C)
    return A.uplo == 'U' ? A_mul_B_U_kernel!(α, A, B, β, C) :
                           A_mul_B_L_kernel!(α, A, B, β, C)
end

"""
    A_mul_B_nocheck!(α, A, B, β, C)

Like `A_mul_B!`, but the kernel does not skip entries outside the `A.uplo` triangle:
every stored entry of `A.data` is used and mirrored, and `A.uplo` is ignored. The result
is only correct when `A.data` stores a single triangle (e.g. the output of `triu`/`tril`).
Dimensions are still checked.
"""
function A_mul_B_nocheck!(α::Number, A::Symmetric{TA,SparseMatrixCSC{TA,S}},
                          B::StridedVecOrMat, β::Number, C::StridedVecOrMat) where {TA,S}
    _check_mul_dims(A, B, C)
    return A_mul_B_nocheck_kernel!(α, A, B, β, C)
end

"""
    A_mul_B(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, x::StridedVector)

Return a new vector `A * x` with element type `promote_type(TA, eltype(x))`.
"""
function A_mul_B(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx}
    T = promote_type(TA, Tx)
    return A_mul_B!(one(T), A, x, zero(T), similar(x, T, A.data.n))
end

"""
    A_mul_B(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, B::StridedMatrix)

Return a new matrix `A * B` with element type `promote_type(TA, eltype(B))`.
"""
function A_mul_B(A::Symmetric{TA,SparseMatrixCSC{TA,S}}, B::StridedMatrix{Tx}) where {TA,S,Tx}
    T = promote_type(TA, Tx)
    return A_mul_B!(one(T), A, B, zero(T), similar(B, T, (A.data.n, size(B, 2))))
end

# Apply the `β * C` part of `α * A * B + β * C` in place. For `β == 0`, `fill!` is used
# instead of scaling, because `0 * NaN` would keep garbage from an uninitialized `C`.
function _scale_output!(C::StridedVecOrMat, β::Number)
    if β != 1
        β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
    end
    return C
end

# Upper-triangle kernel. For each column `col`, walk the stored entries with row <= col
# (indices assumed sorted, so stop at the first row > col). An entry a = A[row, col]
# contributes a * B[col] to C[row] and, by symmetry, a * B[row] to C[col]. The second
# part is accumulated in `acc` and added once per column. Both parts are scaled by α.
function A_mul_B_U_kernel!(α::Number, A::Symmetric{TA,SparseMatrixCSC{TA,S}},
                           B::StridedVecOrMat, β::Number, C::StridedVecOrMat) where {TA,S}
    colptr = A.data.colptr
    rowval = A.data.rowval
    nzval = A.data.nzval
    _scale_output!(C, β)
    # Accumulator type matches `a * B[row, k]`, keeping the inner loop type-stable.
    T = promote_type(TA, eltype(B))
    @inbounds for k = 1:size(C, 2)
        for col = 1:A.data.n
            αxj = α * B[col, k]
            acc = zero(T)
            for j = colptr[col]:(colptr[col + 1] - 1)
                row = rowval[j]
                row > col && break # assume indices are sorted
                a = nzval[j]
                C[row, k] += a * αxj
                row == col || (acc += a * B[row, k])
            end
            C[col, k] += α * acc # mirrored part must be scaled by α too
        end
    end
    return C
end

# Lower-triangle kernel. Same scheme as the upper kernel, but each column is walked from
# its last stored entry backwards, stopping at the first row < col (indices assumed
# sorted).
function A_mul_B_L_kernel!(α::Number, A::Symmetric{TA,SparseMatrixCSC{TA,S}},
                           B::StridedVecOrMat, β::Number, C::StridedVecOrMat) where {TA,S}
    colptr = A.data.colptr
    rowval = A.data.rowval
    nzval = A.data.nzval
    _scale_output!(C, β)
    T = promote_type(TA, eltype(B))
    @inbounds for k = 1:size(C, 2)
        for col = 1:A.data.n
            αxj = α * B[col, k]
            acc = zero(T)
            for j = (colptr[col + 1] - 1):-1:colptr[col]
                row = rowval[j]
                row < col && break # assume indices are sorted
                a = nzval[j]
                C[row, k] += a * αxj
                row == col || (acc += a * B[row, k])
            end
            C[col, k] += α * acc # mirrored part must be scaled by α too
        end
    end
    return C
end

# Branch-free kernel. Uses every stored entry of each column with no triangle test, so it
# is only correct when `A.data` holds a single triangle (entries from both triangles
# would be counted twice).
function A_mul_B_nocheck_kernel!(α::Number, A::Symmetric{TA,SparseMatrixCSC{TA,S}},
                                 B::StridedVecOrMat, β::Number,
                                 C::StridedVecOrMat) where {TA,S}
    colptr = A.data.colptr
    rowval = A.data.rowval
    nzval = A.data.nzval
    _scale_output!(C, β)
    T = promote_type(TA, eltype(B))
    @inbounds for k = 1:size(C, 2)
        for col = 1:A.data.n
            αxj = α * B[col, k]
            acc = zero(T)
            for j = colptr[col]:(colptr[col + 1] - 1)
                row = rowval[j]
                a = nzval[j]
                C[row, k] += a * αxj
                row == col || (acc += a * B[row, k])
            end
            C[col, k] += α * acc # mirrored part must be scaled by α too
        end
    end
    return C
end

"""
    runtests(N=10, p=0.32)

Benchmark `C = A * B` for a random symmetric `N×N` sparse `A` (built by symmetrizing
`sprand(N, N, p)`) and a dense `N×N` `B`, with four variants:
- the plain `SparseMatrixCSC` product;
- `Symmetric` with both triangles stored;
- `Symmetric` with only `triu` stored;
- the branch-free kernel on the `triu`-only matrix.

Prints each benchmark result. Returns `true` if all four products agree up to
floating-point rounding.
"""
function runtests(N=10, p=0.32)
    A = sprand(N, N, p)
    A = sparse(full(Symmetric(full(A)))) # symmetric, both triangles stored
    B = rand(N, N)
    C1 = similar(B); C2 = similar(B); C3 = similar(B); C4 = similar(B)
    bres = @benchmark A_mul_B!(1.0, $A, $B, 0.0, $C1)
    println("Normal sparse: $bres")
    A = Symmetric(A)
    bres = @benchmark A_mul_B!(1.0, $A, $B, 0.0, $C2)
    println("Symmetric sparse: $bres")
    A = Symmetric(triu(A.data))
    bres = @benchmark A_mul_B!(1.0, $A, $B, 0.0, $C3)
    println("Symmetric sparse (triu only): $bres")
    bres = @benchmark SymmetricSparseTests.A_mul_B_nocheck!(1.0, $A, $B, 0.0, $C4)
    println("Symmetric sparse (triu only) optimized: $bres")
    # The variants sum in different orders, so compare up to rounding, not bit-exactly.
    return C1 ≈ C2 && C1 ≈ C3 && C1 ≈ C4
end

end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment