Skip to content

Instantly share code, notes, and snippets.

@tkelman
Last active August 29, 2015 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tkelman/9175190 to your computer and use it in GitHub Desktop.
Save tkelman/9175190 to your computer and use it in GitHub Desktop.
# spmm(Ain, Bin)
#
# Compute the sparse product `Ain * Bin` with Gustavson's algorithm and return
# it as a `SparseMatrixCSC` whose row indices are sorted within every column.
#
# Gustavson's algorithm builds the product one column at a time, but emits the
# row indices of each column in discovery order rather than sorted order. To get
# sorted indices cheaply, this computes C = transpose(Bin) * transpose(Ain),
# which equals transpose(Ain * Bin), and transposes C at the end. A CSC
# transpose walks the source columns in increasing order, so it always produces
# sorted row indices.
#
# The result's value type is `promote_type` of the inputs' value types, and its
# index type is the promotion of their index types. When both inputs are
# `SparseMatrixCSC{Tv,Ti}`, the result is `SparseMatrixCSC{Tv,Ti}`.
#
# Throws `DimensionMismatch` if `size(Ain, 2) != size(Bin, 1)`.
function spmm(Ain::SparseMatrixCSC, Bin::SparseMatrixCSC)
    # `copy(transpose(S))` yields a plain SparseMatrixCSC on every Julia version.
    # On newer versions `transpose` alone returns a lazy wrapper without the
    # `colptr`/`rowval`/`nzval` fields used below.
    A = copy(transpose(Bin))
    B = copy(transpose(Ain))
    mA, nA = size(A)
    mB, nB = size(B)
    nA == mB || throw(DimensionMismatch(
        "matrix A has dimensions $(size(Ain)), matrix B has dimensions $(size(Bin))"))
    Tv = promote_type(eltype(Ain.nzval), eltype(Bin.nzval))
    Ti = promote_type(eltype(Ain.rowval), eltype(Bin.rowval))
    colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
    colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval
    colptrC = zeros(Ti, nB + 1)
    rowvalC = zeros(Ti, 0)
    nzvalC = zeros(Tv, 0)
    # xb[k] == i means row k already has an entry in column i of C.
    # Markers hold column indices (Int), so they are not stored as Ti.
    xb = zeros(Int, mA)
    @inbounds begin
        # Symbolic pass: count the nonzeros of each column of C to build colptrC.
        ip = 1
        for i in 1:nB
            colptrC[i] = ip
            for jp in colptrB[i]:(colptrB[i+1] - 1)
                j = rowvalB[jp]
                for kp in colptrA[j]:(colptrA[j+1] - 1)
                    k = rowvalA[kp]
                    if xb[k] != i
                        ip += 1
                        xb[k] = i
                    end
                end
            end
        end
        colptrC[nB+1] = ip
        # ip is one past the last entry, so C stores exactly ip - 1 entries.
        # Every slot is overwritten by the numeric pass below.
        resize!(rowvalC, ip - 1)
        resize!(nzvalC, ip - 1)

        # Numeric pass: accumulate each column of C in the dense vector x,
        # then gather the values at the row indices recorded for that column.
        fill!(xb, 0)
        x = zeros(Tv, mA)
        ip = 1
        for i in 1:nB
            for jp in colptrB[i]:(colptrB[i+1] - 1)
                nzB = nzvalB[jp]
                j = rowvalB[jp]
                for kp in colptrA[j]:(colptrA[j+1] - 1)
                    nzC = nzvalA[kp] * nzB
                    k = rowvalA[kp]
                    if xb[k] != i
                        # First contribution to row k in column i: record the index.
                        rowvalC[ip] = k
                        ip += 1
                        xb[k] = i
                        x[k] = nzC
                    else
                        x[k] += nzC
                    end
                end
            end
            for vp in colptrC[i]:(ip - 1)
                nzvalC[vp] = x[rowvalC[vp]]
            end
        end
    end
    # C holds transpose(Ain * Bin) with unsorted row indices. Transposing it
    # restores the orientation and sorts the indices.
    Cunsorted = SparseMatrixCSC(mA, nB, colptrC, rowvalC, nzvalC)
    return copy(transpose(Cunsorted))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment