Last active
August 29, 2015 14:03
-
-
Save bjodah/04a95949f7d9747741b3 to your computer and use it in GitHub Desktop.
Prototype of a dictionary-of-keys (DOK) based sparse (square) matrix, implemented in Cython.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from __future__ import (absolute_import, division, | |
print_function, unicode_literals) | |
from sparse import SMat | |
def _print(a):
    """Print every element of the square matrix *a*, one row per line.

    *a* must expose its dimension as ``a.n`` and support ``a[i, j]``
    indexing.  Each element is written followed by a single space, each
    row ends with a newline, and one extra newline follows the last row.
    """
    size = a.n
    for row in range(size):
        for col in range(size):
            print('%s ' % (a[row, col],), end='')
        print()
    print()
def main(n):
    """Exercise SMat: element assignment, population counts and addition.

    Builds two ``n`` x ``n`` matrices, checks their ``npop`` counts, adds
    them and checks selected entries of the sum, printing each matrix via
    ``_print`` along the way.  ``n`` must be at least 3 because index
    ``(2, 2)`` is assigned.
    """
    lhs = SMat(n)
    _print(lhs)
    lhs[1, 1] = 2.0
    _print(lhs)
    assert lhs.npop == 1

    rhs = SMat(n)
    rhs[1, 1] = 8.0
    rhs[2, 2] = 4.0
    _print(rhs)
    assert rhs.npop == 2

    total = lhs + rhs
    # The sum holds exactly two distinct keys, (1, 1) and (2, 2), so a
    # correct implementation reports 2 here; anything larger means duplicate
    # entries were stored for the same key.
    print('bug: ', total.npop)
    expected = (
        ((1, 1), 8 + 2),
        ((2, 2), 4),
        ((0, 0), 0),
    )
    for key, value in expected:
        assert total[key] == value
    _print(total)
if __name__ == '__main__':
    # When executed as a script, run the checks on 3x3 matrices.
    main(n=3)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Build script for the ``sparse`` Cython extension (``sparse.pyx``)."""
try:
    # distutils was deprecated by PEP 632 and removed in Python 3.12;
    # setuptools provides a drop-in ``setup`` with the same signature.
    from setuptools import setup
except ImportError:  # fall back for old environments without setuptools
    from distutils.core import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize("sparse.pyx"),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# distutils: language = c++ | |
import cython | |
from libcpp.pair cimport pair | |
from libcpp.vector cimport vector | |
# A single stored matrix entry: value ``x`` at index pair (``i``, ``j``).
cdef struct keyval:
    int i
    int j
    double x

ctypedef keyval Keyval
cdef SMat_from_add(SMat l, SMat r):
    # Return a new SMat holding the element-wise sum of ``l`` and ``r``.
    # Raises ValueError if the two matrices have different dimensions.
    cdef int n = l.n
    if n != r.n:
        raise ValueError("Size mismatch.")
    cdef int i
    cdef size_t j
    cdef int o = 0
    cdef SMat x, res
    cdef Keyval kv
    # Reserve about twice the average bucket occupancy of the operands.
    # Use floor division explicitly: under Python-3 division semantics ``/``
    # on C ints produces a double, which is not a valid ``int o`` argument.
    # Guard n == 0 so we never divide by zero (o == 0 selects SMat's default).
    if n > 0:
        o = (l.npop + r.npop) * 2 // n
    res = SMat(n, o)
    for x in (l, r):
        for i in range(n):
            for j in range(x.data[i].size()):
                kv = x.data[i][j]
                res._set(kv.i, kv.j, kv.x + res._get(kv.i, kv.j))
    return res
cdef class SMat:
    """Square ``n`` x ``n`` sparse matrix stored as a dictionary of keys.

    Entries are kept in ``n`` buckets (``data``); ``_hash`` maps an index
    pair to its bucket and each bucket is scanned linearly.  Entries that
    were never set read as ``0.0``.

    Read-only attributes:
        n: the matrix dimension.
        npop: the number of stored entries.
    """
    cdef vector[vector[Keyval]] data
    cdef readonly int n, npop

    def peek(self):
        """Return the raw buckets as a list of lists of ``(i, j, x)`` tuples."""
        cdef vector[Keyval] v
        cdef Keyval kv
        return [[(kv.i, kv.j, kv.x) for kv in v] for v in self.data]

    def __cinit__(self, int n, int o=0):
        """Create an empty ``n`` x ``n`` matrix.

        Parameters
        ----------
        n : int
            Matrix dimension (also the number of hash buckets).
        o : int, optional
            Capacity reserved in each bucket; 0 (the default) means
            ``int(n**0.5)``.

        Raises
        ------
        ValueError
            If ``n`` or ``o`` is negative.
        """
        cdef int i
        cdef vector[Keyval] v
        # A negative size would be converted to a huge size_t by assign/reserve.
        if n < 0:
            raise ValueError("n must be non-negative")
        if o < 0:
            raise ValueError("o must be non-negative")
        self.n = n
        if o == 0:
            o = int(n**0.5)
        self.data.assign(n, v)
        for i in range(n):
            self.data[i].reserve(o)
        self.npop = 0

    @cython.cdivision(True)
    cdef int _hash(self, int i, int j):
        # Bucket index for (i, j).  The product is widened to long long so that
        # (n - 2) * j cannot overflow a C int for large n.  Callers must pass
        # indices in [0, n) and n > 0: the result is then in [0, n), and C
        # modulo (cdivision) never sees a negative operand or a zero divisor.
        return <int>((i + <long long>(self.n - 2) * j) % self.n)

    cdef int _check_index(self, int ri, int ci) except -1:
        # Raise IndexError unless both indices lie in [0, n).
        if not (0 <= ri < self.n and 0 <= ci < self.n):
            raise IndexError("index (%d, %d) out of range for %dx%d matrix"
                             % (ri, ci, self.n, self.n))
        return 0

    cdef _get(self, int ri, int ci):
        # Return the value stored at (ri, ci), or 0.0 if absent (no bounds check).
        cdef size_t i
        cdef int h = self._hash(ri, ci)
        for i in range(self.data[h].size()):
            if ri == self.data[h][i].i and ci == self.data[h][i].j:
                return self.data[h][i].x
        return 0.0

    def __getitem__(self, pair[int, int] index):
        """Return the value at ``index`` (an ``(i, j)`` pair); 0.0 if unset.

        Raises IndexError if either index is outside ``[0, n)``.
        """
        self._check_index(index.first, index.second)
        return self._get(index.first, index.second)

    cdef _set(self, int ri, int ci, double value):
        # Store ``value`` at (ri, ci), overwriting an existing entry for that
        # key or appending a new one (no bounds check).
        cdef size_t i
        cdef int h = self._hash(ri, ci)
        cdef Keyval kv
        for i in range(self.data[h].size()):
            if ri == self.data[h][i].i and ci == self.data[h][i].j:
                # Overwrite in place and stop here so the key is not also
                # appended below (which would duplicate it and inflate npop).
                self.data[h][i].x = value
                return
        kv.i = ri
        kv.j = ci
        kv.x = value
        self.data[h].push_back(kv)
        self.npop += 1

    def __setitem__(self, pair[int, int] index, double value):
        """Store ``value`` at ``index`` (an ``(i, j)`` pair).

        Raises IndexError if either index is outside ``[0, n)``.
        """
        self._check_index(index.first, index.second)
        self._set(index.first, index.second, value)

    def __add__(left, right):
        """Return the element-wise sum of two SMat instances.

        Raises NotImplementedError if either operand is not an SMat, and
        ValueError (from SMat_from_add) if their sizes differ.
        """
        # NOTE(review): returning NotImplemented is the usual Python binary
        # operator protocol; raising is kept for backward compatibility.
        if isinstance(left, SMat) and isinstance(right, SMat):
            return SMat_from_add(left, right)
        else:
            raise NotImplementedError
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment