Skip to content

Instantly share code, notes, and snippets.

@gwerbin
Created September 26, 2019 02:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gwerbin/8d089850143f2c7764f30ed1d2c2881c to your computer and use it in GitHub Desktop.
Get the column indices of the stored entries in each row of a sparse matrix (CSR or LIL format)
import numpy as np
from scipy import sparse
# Small 4x3 example matrix of 0/1 values; each row has one or two nonzeros.
x = np.array([(1, 0, 1), (0, 0, 1), (1, 1, 0), (0, 1, 1)])
def get_col_indices_csr(mat):
    """Return, for each row of a CSR matrix, the column indices of its stored entries.

    Parameters
    ----------
    mat : scipy.sparse CSR matrix or array
        Both ``csr_matrix`` and ``csr_array`` are accepted.

    Returns
    -------
    list of numpy.ndarray
        One array per row. Each array is a slice (view) of ``mat.indices``,
        so nothing is copied; do not modify the arrays in place.
        Indices appear in storage order, which is ascending only when
        ``mat.has_sorted_indices`` is true. Explicitly stored zeros are
        included, since they occupy slots in ``mat.indices``.

    Raises
    ------
    ValueError
        If ``mat`` is not a scipy sparse object in CSR format.
    """
    # ``format`` is shared by csr_matrix and csr_array, whereas
    # ``sparse.isspmatrix_csr`` returns False for the sparse-array classes.
    if not (sparse.issparse(mat) and mat.format == "csr"):
        raise ValueError('Only applicable for CSR matrices')
    indptr = mat.indptr
    indices = mat.indices
    # Row i's column indices live in indices[indptr[i]:indptr[i + 1]].
    return [indices[start:stop] for start, stop in zip(indptr[:-1], indptr[1:])]
def get_col_indices_lil(mat):
    """Return, for each row of a LIL matrix, the column indices of its stored entries.

    Parameters
    ----------
    mat : scipy.sparse LIL matrix or array
        Both ``lil_matrix`` and ``lil_array`` are accepted.

    Returns
    -------
    list of list of int
        One list per row. Each list is a copy of the corresponding entry of
        ``mat.rows``, so callers may modify the result freely. Returning the
        internal lists would let such edits desynchronize ``mat.rows`` from
        ``mat.data`` and corrupt ``mat``.

    Raises
    ------
    ValueError
        If ``mat`` is not a scipy sparse object in LIL format.
    """
    # ``format`` is shared by lil_matrix and lil_array, whereas
    # ``sparse.isspmatrix_lil`` returns False for the sparse-array classes.
    if not (sparse.issparse(mat) and mat.format == "lil"):
        raise ValueError('Only applicable for LIL matrices')
    return [list(row) for row in mat.rows]
# Expected column indices of the stored entries in each row of ``x``.
expected = [
    [0, 2],
    [2],
    [0, 1],
    [1, 2],
]
# assert_array_equal compares values only, so this int32 dtype is not
# checked against the dtype of the returned index arrays.
expected = [np.array(cols, dtype=np.int32) for cols in expected]

result = get_col_indices_csr(sparse.csr_matrix(x))
# Compare row counts explicitly: the per-row loop alone would ignore extra
# rows in ``result`` and fail with a bare IndexError on missing ones.
np.testing.assert_equal(len(result), len(expected), err_msg="CSR: row count mismatch")
for i, (row_expected, row_result) in enumerate(zip(expected, result)):
    np.testing.assert_array_equal(row_expected, row_result, err_msg=f"CSR: row {i}")

result = get_col_indices_lil(sparse.lil_matrix(x))
np.testing.assert_equal(len(result), len(expected), err_msg="LIL: row count mismatch")
for i, (row_expected, row_result) in enumerate(zip(expected, result)):
    np.testing.assert_array_equal(row_expected, row_result, err_msg=f"LIL: row {i}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment