Skip to content

Instantly share code, notes, and snippets.

@bjodah
Last active August 29, 2015 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bjodah/04a95949f7d9747741b3 to your computer and use it in GitHub Desktop.
Save bjodah/04a95949f7d9747741b3 to your computer and use it in GitHub Desktop.
Prototype of a dictionary-of-keys (DOK) based sparse (square) matrix, implemented in Cython.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division,
print_function, unicode_literals)
from sparse import SMat
def _print(a):
    """Write a dense, space-separated rendering of square matrix *a* to stdout.

    Every entry (read as ``a[row, col]`` for row, col in ``range(a.n)``) is
    followed by one space, every row is terminated by a newline, and one
    extra newline closes the matrix.
    """
    size = a.n
    for row in range(size):
        for col in range(size):
            print(a[row, col], end=' ')
        print()
    print()
def main(n):
    """Smoke-test SMat on an n x n matrix (indices up to 2 are used, so n >= 3).

    Builds one matrix with a single entry and one with two entries, prints
    each with ``_print`` and checks their populated-entry counts (``npop``),
    then adds them, reports the sum's ``npop`` and checks selected entries of
    the sum before printing it.
    """
    lhs = SMat(n)
    _print(lhs)
    lhs[1, 1] = 2.0
    _print(lhs)
    assert lhs.npop == 1

    rhs = SMat(n)
    for key, value in (((1, 1), 8.0), ((2, 2), 4.0)):
        rhs[key] = value
    _print(rhs)
    assert rhs.npop == 2

    total = lhs + rhs
    # The author flags this count as a bug: the sum has only two distinct
    # keys, (1, 1) and (2, 2), so presumably 2 is expected. An SMat._set that
    # appends a duplicate entry when overwriting an existing key over-counts
    # here.
    print('bug: ', total.npop)
    for key, expected in (((1, 1), 8 + 2), ((2, 2), 4), ((0, 0), 0)):
        assert total[key] == expected
    _print(total)
if __name__ == '__main__':
    # Script entry point: run the smoke test on a 3 x 3 matrix.
    main(n=3)
# Build script for the ``sparse`` Cython extension (sparse.pyx).
try:
    # distutils is deprecated (PEP 632) and no longer ships with Python 3.12+;
    # setuptools provides a drop-in ``setup``.
    from setuptools import setup
except ImportError:
    # Fallback for old environments that lack setuptools.
    from distutils.core import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize("sparse.pyx"),
)
# -*- coding: utf-8 -*-
# distutils: language = c++
import cython
from libcpp.pair cimport pair
from libcpp.vector cimport vector
cdef struct keyval:
    # One stored matrix entry: row index i, column index j and value x.
    int i
    int j
    double x

# Capitalised alias used throughout the module for the entry record.
ctypedef keyval Keyval
cdef SMat_from_add(SMat l, SMat r):
    # Return a new SMat holding the element-wise sum l + r.
    # Raises ValueError if the two operands differ in size.
    cdef int n = l.n
    if not n == r.n:
        raise ValueError("Size mismatch.")
    cdef int i, j, o
    cdef SMat x, res
    cdef Keyval kv
    # Per-bucket capacity hint: twice the average number of stored entries
    # per bucket. Floor division keeps it an integer under both Python 2 and
    # Python 3 division semantics (``/`` can produce a float under
    # language_level=3, which is not a valid value for the constructor's
    # ``int o`` parameter). The guard avoids dividing by zero for a 0 x 0
    # matrix; 0 makes the constructor pick its own default.
    o = (l.npop + r.npop) * 2 // n if n > 0 else 0
    res = SMat(n, o)
    for x in (l, r):
        for i in range(n):
            for j in range(x.data[i].size()):
                kv = x.data[i][j]
                # Accumulate, so keys present in both operands are summed.
                res._set(kv.i, kv.j, kv.x + res._get(kv.i, kv.j))
    return res
cdef class SMat:
    # Square n x n sparse matrix stored as a hash table with n buckets; each
    # bucket is a vector of (i, j, x) entries. Keys that are not stored read
    # as 0.0.
    cdef vector[vector[Keyval]] data
    # n: matrix dimension (also the number of buckets).
    # npop: number of stored (populated) entries.
    cdef readonly int n, npop

    def peek(self):
        """Return the raw bucket contents as a list of lists of (i, j, x) tuples."""
        cdef size_t h, k
        cdef Keyval kv
        buckets = []
        # Explicit loops (rather than nested comprehensions) keep ``kv`` typed
        # as the C struct; comprehension loop variables may get their own
        # scope depending on the Cython language level.
        for h in range(self.data.size()):
            entries = []
            for k in range(self.data[h].size()):
                kv = self.data[h][k]
                entries.append((kv.i, kv.j, kv.x))
            buckets.append(entries)
        return buckets

    def __cinit__(self, int n, int o=0):
        # n: matrix dimension; must be non-negative.
        # o: capacity reserved in each bucket; 0 (or negative) selects
        #    int(sqrt(n)).
        cdef int i
        cdef vector[Keyval] v
        if n < 0:
            raise ValueError("Matrix size must be non-negative.")
        self.n = n
        if o <= 0:
            o = int(n**0.5)
        self.data.assign(n, v)
        for i in range(n):
            self.data[i].reserve(o)
        self.npop = 0

    @cython.cdivision(True)
    cdef int _hash(self, int i, int j):
        # Bucket index for key (i, j), equal to (i + (n - 2) * j) mod n.
        # Computed in 64-bit so (n - 2) * j cannot overflow a C int for large
        # n. The result is normalised into [0, n) because C's % (cdivision)
        # keeps the sign of the dividend. Callers must ensure n > 0.
        cdef long long h = (<long long>i + (<long long>self.n - 2) * j) % self.n
        if h < 0:
            h += self.n
        return <int>h

    cdef _get(self, int ri, int ci):
        # Return the stored value at (ri, ci), or 0.0 if the key is absent.
        cdef int i
        cdef int h = self._hash(ri, ci)
        for i in range(self.data[h].size()):
            if ri == self.data[h][i].i and ci == self.data[h][i].j:
                return self.data[h][i].x
        return 0.0

    cdef _check_index(self, int ri, int ci):
        # Reject keys outside the n x n matrix. Otherwise an out-of-range key
        # would silently land in some bucket, or (for negative indices) could
        # index the bucket vector out of bounds.
        if not (0 <= ri < self.n and 0 <= ci < self.n):
            raise IndexError("Index (%d, %d) out of range for %d x %d matrix."
                             % (ri, ci, self.n, self.n))

    def __getitem__(self, pair[int, int] index):
        self._check_index(index.first, index.second)
        return self._get(index.first, index.second)

    cdef _set(self, int ri, int ci, double value):
        # Store value at (ri, ci). An existing entry is overwritten in place;
        # otherwise a new entry is appended and npop is incremented.
        cdef int i
        cdef int h = self._hash(ri, ci)
        cdef Keyval kv
        for i in range(self.data[h].size()):
            if ri == self.data[h][i].i and ci == self.data[h][i].j:
                self.data[h][i].x = value
                # Key already present: stop here. Falling through (as a
                # for/else without a break does) would append a duplicate
                # entry and over-count npop.
                return
        kv.i = ri
        kv.j = ci
        kv.x = value
        self.data[h].push_back(kv)
        self.npop += 1

    def __setitem__(self, pair[int, int] index, double value):
        self._check_index(index.first, index.second)
        self._set(index.first, index.second, value)

    def __add__(left, right):
        # Element-wise sum of two SMat instances of equal size.
        # NOTE(review): other operand types raise NotImplementedError here.
        # Returning NotImplemented would follow Python's binary-operator
        # protocol, but would change the exception callers see (TypeError).
        if isinstance(left, SMat) and isinstance(right, SMat):
            return SMat_from_add(left, right)
        else:
            raise NotImplementedError
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment