Skip to content

Instantly share code, notes, and snippets.

@wush978
Created May 18, 2020 17:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.
Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.
%%cython --annotate --cplus --compile-args=-fopenmp --link-args=-fopenmp
cimport cython
from cython.parallel cimport prange, parallel, threadid
import numpy as np
cimport numpy as np
import scipy
cimport openmp
# Serial alternative: swapping in std::sort only requires declaring
#   cdef extern from "<algorithm>" namespace "std":
#       cdef void sort[RI](RI first, RI last)
# instead of the parallel-mode declaration below.
#
# libstdc++ parallel-mode algorithms (OpenMP-based, hence the -fopenmp
# compile/link flags on the %%cython line). ``except +`` makes Cython turn a
# thrown C++ exception into a Python exception.
cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel":
    cdef void sort[RI](RI first, RI last) except +
    # NOTE(review): declared but not used anywhere in this cell — remove if
    # it is not needed elsewhere.
    cdef void unique_copy[II, OI](II first, II last, OI d_first) except +
from libcpp.vector cimport vector
from libcpp.utility cimport pair
# One stored matrix entry. __csr_to_csc fills it as (first=col, second=row);
# std::pair compares lexicographically, so sorting a vector of these orders
# entries by column and then by row within a column (CSC order).
ctypedef pair[np.int64_t,np.int64_t] Index
cdef void __csr_to_csc(
    np.int64_t *src_indptr,
    np.int64_t *src_indices,
    np.int64_t *dst_indptr,
    np.int64_t *dst_indices,
    size_t nrow,
    size_t ncol,
    size_t nnz,
):
    """Transpose the sparsity pattern of a CSR matrix into CSC index arrays.

    Parameters
    ----------
    src_indptr : pointer to the ``nrow + 1`` CSR row offsets.
    src_indices : pointer to the CSR column indices; entries
        ``src_indptr[i] .. src_indptr[i + 1] - 1`` belong to row ``i``.
    dst_indptr : pointer to ``ncol + 1`` output column offsets. MUST be
        zero-filled by the caller: per-column counts are added into it
        with ``+=`` before the final prefix sum.
    dst_indices : pointer to ``nnz`` output row indices (overwritten).
    nrow, ncol : matrix shape.
    nnz : number of stored entries. Assumed to equal ``src_indptr[nrow]``
        with ``src_indptr[0] == 0``. This is not checked here; a mismatch
        reads/writes out of range or converts stale pairs.

    Only the pattern is converted. No values are touched, so the caller must
    supply a data array that is valid in any entry order (e.g. all ones).
    """
    # One (col, row) pair per stored entry.
    cdef vector[Index] index
    # One column histogram per possible OpenMP thread id.
    cdef vector[vector[np.int64_t]] buffer = vector[vector[np.int64_t]](openmp.omp_get_max_threads())
    cdef size_t i, j
    with nogil:
        index.resize(nnz)
        for i in prange(nrow):
            for j in range(src_indptr[i], src_indptr[i+1]):
                index[j].second = i  # row
                index[j].first = src_indices[j]  # col
        # Sorting (col, row) pairs yields CSC order with rows ascending
        # inside each column.
        sort[vector[Index].iterator](index.begin(), index.end())
    with nogil, parallel():
        # Each thread sizes (value-initialising to 0) only its own histogram.
        buffer[threadid()].resize(ncol)
        for i in prange(nnz):
            buffer[threadid()][index[i].first] += 1
            dst_indices[i] = index[i].second
        # Relies on the implicit barrier at the end of the worksharing loop
        # above, so every histogram is complete here. Slots for thread ids
        # that did not run in this region were never resized and stay empty,
        # hence the size check.
        for i in prange(ncol):
            for j in range(buffer.size()):
                if buffer[j].size() > 0:
                    dst_indptr[i+1] += buffer[j][i]
    # Serial prefix sum turns per-column counts into column offsets.
    for i in range(ncol):
        dst_indptr[i+1] = dst_indptr[i+1] + dst_indptr[i]
cdef np.int64_t* getp(np.ndarray[np.int64_t, ndim=1, mode="c"] arr) except? NULL:
    # Return a raw pointer to the first element of a 1-D int64 array.
    #
    # mode="c" makes Cython reject a non-contiguous array (ValueError).
    # Otherwise it would hand out a pointer whose unit-stride indexing reads
    # the wrong elements.
    #
    # An empty array has no first element: ``&arr[0]`` would raise
    # IndexError under bounds checking, so NULL is returned instead. Callers
    # must not dereference the result for an empty array.
    #
    # ``except? NULL`` lets a buffer dtype/contiguity error propagate to the
    # caller instead of being swallowed. The pointer stays valid only while
    # ``arr`` is alive and not resized.
    if arr.shape[0] == 0:
        return NULL
    return &arr[0]
def csr_to_csc(m):
    """Convert a binary CSR matrix to CSC format with OpenMP parallelism.

    Parameters
    ----------
    m : scipy.sparse.csr_matrix
        Matrix whose ``indptr`` and ``indices`` are int64 and whose stored
        values are all exactly 1. Only the sparsity pattern is transposed.

    Returns
    -------
    scipy.sparse.csc_matrix
        The same matrix in CSC format, with row indices sorted within each
        column. Its data array is a copy and does not alias ``m.data``.

    Raises
    ------
    RuntimeError
        If ``m`` is not a csr_matrix, its index arrays are not int64,
        ``indices`` is shorter than ``indptr[-1]``, or a stored value is
        not 1.
    """
    # ``import scipy`` alone does not guarantee that the sparse subpackage is
    # loaded. ``scipy.sparse.csr`` is a deprecated private namespace.
    import scipy.sparse

    if not isinstance(m, scipy.sparse.csr_matrix):
        raise RuntimeError("m is not a csr_matrix")
    if not m.indptr.dtype == np.int64:
        raise RuntimeError("The indptr is not int64")
    if not m.indices.dtype == np.int64:
        raise RuntimeError("The indices are not int64")
    # indptr[-1] is the real entry count. A non-pruned matrix may carry
    # trailing slots in indices/data that are not part of the matrix.
    nnz = int(m.indptr[-1])
    if len(m.indices) < nnz:
        raise RuntimeError("The indices is shorter than indptr[-1]")
    data = m.data[:nnz]
    if not np.all(data == 1):
        raise RuntimeError("The data is not all 1")
    dst_indptr = np.zeros(m.shape[1] + 1, dtype=np.int64)
    dst_indices = np.zeros(nnz, dtype=np.int64)
    # With no stored entries, the zero-filled outputs are already the
    # answer. Skipping the native call also avoids taking a pointer into
    # empty arrays.
    if nnz > 0:
        __csr_to_csc(
            getp(m.indptr),
            getp(m.indices),
            getp(dst_indptr),
            getp(dst_indices),
            m.shape[0],
            m.shape[1],
            nnz,
        )
    # Copy the (all-ones) data so the result does not share memory with m.
    return scipy.sparse.csc_matrix(
        (data.copy(), dst_indices, dst_indptr), shape=m.shape
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment