# Get the column indices out of a sparse array.
#
# GitHub gist gwerbin/8d089850143f2c7764f30ed1d2c2881c, created September 26, 2019 02:46.
# (Save it to your computer and use it in GitHub Desktop.)
#
# Note from GitHub: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters. See GitHub's
# documentation to learn more about bidirectional Unicode characters.
import numpy as np
from scipy import sparse

# Small dense 0/1 fixture (4 rows x 3 columns). Every row has a distinct set
# of nonzero columns, and row 1 has only a single nonzero.
x = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 0], [0, 1, 1]])
def get_col_indices_csr(mat):
    """Return, for each row of a CSR matrix, the column indices of its stored entries.

    Parameters
    ----------
    mat : scipy.sparse CSR matrix or array
        ``csr_matrix`` is accepted, and so is ``csr_array`` on SciPy versions
        that provide it.

    Returns
    -------
    list of numpy.ndarray
        One array per row, listing the column indices of that row's stored
        entries in storage order. Explicitly stored zeros are included. Each
        array is a view into ``mat.indices``, so modifying it modifies ``mat``.

    Raises
    ------
    ValueError
        If ``mat`` is not a scipy.sparse object in CSR format.
    """
    # sparse.isspmatrix_csr() returns False for csr_array on newer SciPy, so
    # test the storage format directly to accept both the matrix and array APIs.
    if not (sparse.issparse(mat) and mat.format == "csr"):
        raise ValueError('Only applicable for CSR matrices')
    indptr = mat.indptr
    indices = mat.indices
    # Row i's entries occupy indices[indptr[i]:indptr[i + 1]].
    return [indices[start:stop] for start, stop in zip(indptr[:-1], indptr[1:])]
def get_col_indices_lil(mat):
    """Return, for each row of a LIL matrix, the column indices of its stored entries.

    Parameters
    ----------
    mat : scipy.sparse LIL matrix or array
        ``lil_matrix`` is accepted, and so is ``lil_array`` on SciPy versions
        that provide it.

    Returns
    -------
    list
        ``mat.rows`` converted to a Python list, with one entry per row.
        ``list()`` copies only the outer container. The per-row objects are
        the same ones ``mat.rows`` holds, so modifying them may modify ``mat``.

    Raises
    ------
    ValueError
        If ``mat`` is not a scipy.sparse object in LIL format.
    """
    # sparse.isspmatrix_lil() returns False for lil_array on newer SciPy, so
    # test the storage format directly to accept both the matrix and array APIs.
    if not (sparse.issparse(mat) and mat.format == "lil"):
        raise ValueError('Only applicable for LIL matrices')
    return list(mat.rows)
# Expected column indices of the nonzeros in each row of ``x``.
expected = [
    [0, 2],
    [2],
    [0, 1],
    [1, 2],
]
# assert_array_equal compares values and shapes but not dtypes (by default),
# so int32 here need not match the index dtype SciPy chooses.
expected = [np.array(row, dtype=np.int32) for row in expected]

# Check both extractors against the same expectation. Each format is built
# and checked in turn, CSR first and then LIL, as in the original script.
for get_indices, to_sparse in (
    (get_col_indices_csr, sparse.csr_matrix),
    (get_col_indices_lil, sparse.lil_matrix),
):
    result = get_indices(to_sparse(x))
    # Compare row counts first: the per-row loop alone would silently ignore
    # extra rows and raise an unhelpful IndexError on missing ones.
    np.testing.assert_equal(len(result), len(expected))
    for want, got in zip(expected, result):
        np.testing.assert_array_equal(want, got)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.