-
-
Save sglyon/7790b35049342be7cc688549ad69285c to your computer and use it in GitHub Desktop.
basis matrices in python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Basis matrices in Python.

Provides numba-compiled row-wise and full Kronecker products for dense and
CSR operands, a sorted-table ``lookup`` search, one-dimensional spline and
piecewise-linear bases, and the multi-dimensional ``Basis`` /
``BasisMatrix`` containers that evaluate them in direct, tensor, or
expanded form.
"""
import numpy as np | |
from numba import njit | |
import quantecon as qe | |
import scipy.sparse as spa | |
import scipy.sparse.linalg as spla | |
import scipy.linalg as la | |
from functools import reduce | |
@njit(cache=True)
def row_kron_dense_csr(m, na, nb, A, Bval, Bind, Bptr, outval, outind, outptr):
    """Fill CSR arrays for the row-wise Kronecker product of dense A and CSR B.

    Row ``r`` of the result is ``kron(A[r, :], B[r, :])``: every entry of
    ``A[r, :]`` is paired with every stored entry of row ``r`` of B and
    lands in column ``col_a * nb + col_b``.

    ``outval``/``outind`` must hold ``na`` times B's stored-entry count and
    ``outptr`` must have length ``m + 1``; all three are written in place
    and returned.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        start = Bptr[r]
        stop = Bptr[r + 1]
        for ja in range(na):
            a_val = A[r, ja]
            offset = ja * nb
            for k in range(start, stop):
                outval[pos] = a_val * Bval[k]
                outind[pos] = offset + Bind[k]
                pos += 1
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def row_kron_dense_dense(m, na, nb, A, B, out):
    """Write the row-wise Kronecker product of dense A and B into ``out``.

    For each row ``r``: ``out[r, ja * nb + jb] = A[r, ja] * B[r, jb]``.
    ``out`` must be at least ``(m, na * nb)``; it is returned.
    """
    for r in range(m):
        for ja in range(na):
            a_val = A[r, ja]
            offset = ja * nb
            for jb in range(nb):
                out[r, offset + jb] = a_val * B[r, jb]
    return out
@njit(cache=True)
def row_kron_csr_dense(m, na, nb, Aval, Aind, Aptr, B, outval, outind, outptr):
    """Fill CSR arrays for the row-wise Kronecker product of CSR A and dense B.

    Each stored entry ``A[r, col_a]`` is paired with all ``nb`` entries of
    ``B[r, :]``, landing in columns ``col_a * nb + col_b``. Output arrays
    are written in place and returned.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        for k in range(Aptr[r], Aptr[r + 1]):
            a_val = Aval[k]
            offset = Aind[k] * nb
            for jb in range(nb):
                outval[pos] = a_val * B[r, jb]
                outind[pos] = offset + jb
                pos += 1
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def row_kron_csr_csr(m, na, nb, Aval, Aind, Aptr, Bval, Bind, Bptr, outval,
                     outind, outptr):
    """Fill CSR arrays for the row-wise Kronecker product of two CSR matrices.

    Stored entries of row ``r`` of A are paired with stored entries of row
    ``r`` of B; the product lands in column ``col_a * nb + col_b``. Output
    arrays are written in place and returned.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        b_start = Bptr[r]
        b_stop = Bptr[r + 1]
        for ka in range(Aptr[r], Aptr[r + 1]):
            a_val = Aval[ka]
            offset = Aind[ka] * nb
            for kb in range(b_start, b_stop):
                outval[pos] = a_val * Bval[kb]
                outind[pos] = offset + Bind[kb]
                pos += 1
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def nnz_row_one_csr(nrow, m_dense, csr_ptr):
    """Count outputs when each stored CSR entry pairs with ``m_dense`` dense entries.

    ``csr_ptr`` is the row-pointer array of the sparse operand over ``nrow``
    rows; the result is ``m_dense`` times its stored entries in those rows.
    """
    total = 0
    for r in range(nrow):
        per_row = csr_ptr[r + 1] - csr_ptr[r]
        total += m_dense * per_row
    return total
@njit(cache=True)
def nnz_row_csr_csr(m, Aptr, Bptr):
    """Count outputs of a row-wise product of two CSR matrices.

    Row ``r`` contributes (stored entries of A's row ``r``) times (stored
    entries of B's row ``r``).
    """
    total = 0
    for r in range(m):
        count_a = Aptr[r + 1] - Aptr[r]
        count_b = Bptr[r + 1] - Bptr[r]
        total += count_a * count_b
    return total
def row_kron(A, B):
    """Row-wise Kronecker product of ``A`` and ``B``.

    Row ``i`` of the result equals ``np.kron(A[i, :], B[i, :])``, so the
    output has shape ``(m, na * nb)`` where ``m`` is the shared row count.

    Parameters
    ----------
    A, B : numpy.ndarray or scipy.sparse matrix
        2-d operands with the same number of rows. Sparse operands are
        converted to CSR (a no-op when they already are CSR).

    Returns
    -------
    numpy.ndarray or scipy.sparse.csr_matrix
        Dense when both inputs are dense, CSR otherwise. The dtype is
        ``np.result_type(A.dtype, B.dtype)``, matching ``np.kron``.

    Raises
    ------
    ValueError
        If ``A`` and ``B`` have different numbers of rows.
    """
    ma, na = A.shape[0], A.shape[1]
    mb, nb = B.shape[0], B.shape[1]
    if ma != mb:
        raise ValueError(
            f"A and B must have the same number of rows; got {ma} and {mb}")

    # Never take the dtype from A alone: int A times float B must stay float.
    out_dtype = np.result_type(A.dtype, B.dtype)
    ncols = na * nb
    a_sparse = spa.issparse(A)
    b_sparse = spa.issparse(B)

    if not a_sparse and not b_sparse:
        out = np.empty((ma, ncols), dtype=out_dtype)
        return row_kron_dense_dense(ma, na, nb, A, B, out)

    # Sparse result: count stored entries, allocate CSR arrays, then fill
    # them with the kernel matching the operand formats.
    As = A.tocsr() if a_sparse else None
    Bs = B.tocsr() if b_sparse else None
    if a_sparse and b_sparse:
        nnz = nnz_row_csr_csr(ma, As.indptr, Bs.indptr)
    elif a_sparse:
        nnz = nnz_row_one_csr(ma, nb, As.indptr)
    else:
        nnz = nnz_row_one_csr(ma, na, Bs.indptr)

    # int32 indices are scipy's default; widen only if they would overflow.
    int32_max = np.iinfo(np.int32).max
    idx_dtype = np.int32 if max(nnz, ncols) <= int32_max else np.int64
    outval = np.zeros(nnz, dtype=out_dtype)
    outind = np.zeros(nnz, dtype=idx_dtype)
    outptr = np.zeros(ma + 1, dtype=idx_dtype)

    if a_sparse and b_sparse:
        row_kron_csr_csr(ma, na, nb, As.data, As.indices, As.indptr,
                         Bs.data, Bs.indices, Bs.indptr,
                         outval, outind, outptr)
    elif a_sparse:
        row_kron_csr_dense(ma, na, nb, As.data, As.indices, As.indptr, B,
                           outval, outind, outptr)
    else:
        row_kron_dense_csr(ma, na, nb, A, Bs.data, Bs.indices, Bs.indptr,
                           outval, outind, outptr)
    return spa.csr_matrix((outval, outind, outptr), shape=(ma, ncols),
                          copy=False)
@njit(cache=True)
def full_kron_dense_csr(ma, mb, na, nb, A, Bval, Bind, Bptr, outval, outind, outptr):
    """Fill CSR arrays of ``kron(A, B)`` for dense A (ma x na) and CSR B (mb rows).

    Output row ``ia * mb + ib`` pairs every entry of ``A[ia, :]`` with the
    stored entries of B's row ``ib``; the column is ``ja * nb + jb``.
    Output arrays are written in place and returned.
    """
    outptr[0] = 0
    pos = 0
    row = 0
    for ia in range(ma):
        for ib in range(mb):
            lo = Bptr[ib]
            hi = Bptr[ib + 1]
            for ja in range(na):
                a_val = A[ia, ja]
                offset = ja * nb
                for k in range(lo, hi):
                    outval[pos] = a_val * Bval[k]
                    outind[pos] = offset + Bind[k]
                    pos += 1
            row += 1
            outptr[row] = pos
    return outval, outind, outptr
@njit(cache=True)
def full_kron_dense_dense(ma, mb, na, nb, A, B, out):
    """Write dense ``kron(A, B)`` into ``out`` (at least ``(ma*mb, na*nb)``); return it.

    ``out[ia * mb + ib, ja * nb + jb] = A[ia, ja] * B[ib, jb]``.
    """
    for ia in range(ma):
        row0 = ia * mb
        for ib in range(mb):
            r = row0 + ib
            for ja in range(na):
                a_val = A[ia, ja]
                c0 = ja * nb
                for jb in range(nb):
                    out[r, c0 + jb] = a_val * B[ib, jb]
    return out
@njit(cache=True)
def full_kron_csr_dense(ma, mb, na, nb, Aval, Aind, Aptr, B, outval, outind, outptr):
    """Fill CSR arrays of ``kron(A, B)`` for CSR A (ma rows) and dense B (mb x nb).

    Output row ``ia * mb + ib`` pairs the stored entries of A's row ``ia``
    with every entry of ``B[ib, :]``. Output arrays are written in place
    and returned.
    """
    outptr[0] = 0
    pos = 0
    row = 0
    for ia in range(ma):
        a_lo = Aptr[ia]
        a_hi = Aptr[ia + 1]
        for ib in range(mb):
            for k in range(a_lo, a_hi):
                a_val = Aval[k]
                offset = Aind[k] * nb
                for jb in range(nb):
                    outval[pos] = a_val * B[ib, jb]
                    outind[pos] = offset + jb
                    pos += 1
            row += 1
            outptr[row] = pos
    return outval, outind, outptr
@njit(cache=True)
def full_kron_csr_csr(ma, mb, na, nb, Aval, Aind, Aptr, Bval, Bind, Bptr, outval,
                      outind, outptr):
    """Fill CSR arrays of ``kron(A, B)`` for two CSR matrices.

    Output row ``ia * mb + ib`` pairs the stored entries of A's row ``ia``
    with those of B's row ``ib``; the column is ``ja * nb + jb``. Output
    arrays are written in place and returned.
    """
    outptr[0] = 0
    pos = 0
    row = 0
    for ia in range(ma):
        a_lo = Aptr[ia]
        a_hi = Aptr[ia + 1]
        for ib in range(mb):
            b_lo = Bptr[ib]
            b_hi = Bptr[ib + 1]
            for ka in range(a_lo, a_hi):
                a_val = Aval[ka]
                offset = Aind[ka] * nb
                for kb in range(b_lo, b_hi):
                    outval[pos] = a_val * Bval[kb]
                    outind[pos] = offset + Bind[kb]
                    pos += 1
            row += 1
            outptr[row] = pos
    return outval, outind, outptr
def full_kron(A, B):
    """Kronecker product ``kron(A, B)`` for any mix of dense and sparse inputs.

    Entry ``(ia * mb + ib, ja * nb + jb)`` of the result is
    ``A[ia, ja] * B[ib, jb]``, as in ``np.kron``.

    Parameters
    ----------
    A, B : numpy.ndarray or scipy.sparse matrix
        2-d operands. Sparse operands are converted to CSR (a no-op when
        they already are CSR).

    Returns
    -------
    numpy.ndarray or scipy.sparse.csr_matrix
        Shape ``(ma * mb, na * nb)``; dense when both inputs are dense, CSR
        otherwise. The dtype is ``np.result_type(A.dtype, B.dtype)``.
    """
    ma, na = A.shape[0], A.shape[1]
    mb, nb = B.shape[0], B.shape[1]
    nrows, ncols = ma * mb, na * nb
    # Never take the dtype from A alone: int A times float B must stay float.
    out_dtype = np.result_type(A.dtype, B.dtype)
    a_sparse = spa.issparse(A)
    b_sparse = spa.issparse(B)

    if not a_sparse and not b_sparse:
        out = np.empty((nrows, ncols), dtype=out_dtype)
        return full_kron_dense_dense(ma, mb, na, nb, A, B, out)

    As = A.tocsr() if a_sparse else None
    Bs = B.tocsr() if b_sparse else None
    # Every stored entry of one operand meets every stored entry of the
    # other (all entries for a dense operand). ``.nnz`` is the count the
    # kernels actually iterate (indptr[-1]), not ``data.size``.
    nnz_a = As.nnz if a_sparse else A.size
    nnz_b = Bs.nnz if b_sparse else B.size
    nnz = nnz_a * nnz_b

    # int32 indices are scipy's default; widen only if they would overflow.
    int32_max = np.iinfo(np.int32).max
    idx_dtype = np.int32 if max(nnz, ncols, nrows) <= int32_max else np.int64
    outval = np.zeros(nnz, dtype=out_dtype)
    outind = np.zeros(nnz, dtype=idx_dtype)
    outptr = np.zeros(nrows + 1, dtype=idx_dtype)

    if a_sparse and b_sparse:
        full_kron_csr_csr(ma, mb, na, nb, As.data, As.indices, As.indptr,
                          Bs.data, Bs.indices, Bs.indptr,
                          outval, outind, outptr)
    elif a_sparse:
        full_kron_csr_dense(ma, mb, na, nb, As.data, As.indices, As.indptr,
                            B, outval, outind, outptr)
    else:
        full_kron_dense_csr(ma, mb, na, nb, A, Bs.data, Bs.indices,
                            Bs.indptr, outval, outind, outptr)
    return spa.csr_matrix((outval, outind, outptr), shape=(nrows, ncols),
                          copy=False)
@njit(cache=True)
def lookup(table, x, p):
    """Locate each ``x[i]`` in a sorted ``table``.

    For every ``x[i]`` the result is the largest index ``j`` with
    ``table[j] <= x[i]`` (``-1`` when ``x[i] < table[0]``), subject to the
    endpoint adjustments selected by ``p``:

    * ``p`` in (1, 3): values below the table map to ``numfirst - 1``
      instead of ``-1``, where ``numfirst`` is the number of leading
      entries equal to ``table[0]`` (i.e. the last of those repeats).
    * ``p >= 2``: the last entry, and every entry equal to ``table[-1]``,
      are excluded from the searched range, so values at or above the top
      of the table map to the last entry below it.

    ``p == 0`` applies neither adjustment and ``p == 3`` applies both.

    Parameters
    ----------
    table : 1-d array
        Breakpoints, assumed sorted ascending (not checked).
    x : 1-d array
        Query points. Each search starts from the previous result and hunts
        outward with doubling steps before bisecting, so (nearly) sorted
        ``x`` is cheapest.
    p : int
        Endpoint-handling flag described above.

    Returns
    -------
    numpy.ndarray of int32 with the same length as ``x``.
    """
    n = table.size
    m = x.size
    out = np.empty((m, ), np.int32)
    # Lower endpoint adjustment: numfirst = number of leading entries equal
    # to table[0].
    numfirst = 1
    while numfirst < n and table[numfirst] == table[0]:
        numfirst += 1
    # Upper endpoint adjustment: for p >= 2 shrink the searched length n so
    # it excludes the last entry and any entries equal to table[-1].
    if p >= 2:
        n -= 1
        for i in range(n - 1, 0, -1):
            if table[i] == table[-1]:
                n -= 1
            else:
                break
    n1 = n - 1
    n2 = n - 2
    # handle 1-value lists separately: every searched entry equals table[0],
    # so there is nothing to bracket.
    if n - numfirst < 1:
        if p == 1 or p == 3:
            for i in range(m):
                out[i] = numfirst - 1
        else:
            for i in range(m):
                if table[0] <= x[i]:
                    out[i] = numfirst - 1
                else:
                    out[i] = -1
        return out
    # jlo carries over between queries: hunt up or down from the previous
    # bracket with doubling steps (inc), then bisect [jlo, jhi).
    jlo = 0
    for i in range(m):
        inc = 1
        xi = x[i]
        if xi >= table[jlo]:
            jhi = jlo + 1
            while xi >= table[jhi]:
                jlo = jhi
                jhi += inc
                if jhi >= n:
                    break
                else:
                    inc += inc
        else:
            jhi = jlo
            jlo -= 1
            while xi < table[jlo]:
                jhi = jlo
                jlo -= inc
                if jlo < 0:
                    jlo = -1
                    break
                else:
                    inc += inc
        while jhi - jlo > 1:
            j = (jhi + jlo) // 2
            # The upward hunt can leave jhi >= n; any midpoint outside the
            # searched range table[:n] is treated as an upper bound.
            if j < n and xi >= table[j]:
                jlo = j
            else:
                jhi = j
        out[i] = jlo
        if jlo < 0:
            # Below the table: restart the next hunt from index 0.
            jlo = 0
            if p == 1 or p == 3:
                out[i] = numfirst - 1
        # Keep the next hunt's first probe (jlo + 1) inside table[:n].
        if jlo == n1:
            jlo = n2
    return out
class LinearBasis(object):
    """Base class for one-dimensional bases.

    Subclasses provide ``basis_matrix(x, orders)`` returning a list of basis
    matrices; calling an instance evaluates the basis at ``x``.
    """

    def __call__(self, x, orders=0):
        """Evaluate the basis at ``x``.

        ``x`` is converted to an at-least-1-d array. When ``orders`` is a
        single integer (Python ``int`` or any NumPy integer type) the first
        matrix from ``basis_matrix`` is returned; otherwise the full list is
        returned.
        """
        x = np.atleast_1d(np.asarray(x))
        out = self.basis_matrix(x, orders)
        # np.integer covers int8/int16/uint* as well as int32/int64.
        if isinstance(orders, (int, np.integer)):
            return out[0]
        return out
class Spline(LinearBasis):
    """Polynomial spline basis of order ``k`` on a sorted breakpoint sequence.

    The basis has ``n = len(breaks) + k - 1`` functions.
    """

    def __init__(self, breaks, evennum=0, k=3):
        """Build the spline basis.

        Parameters
        ----------
        breaks : array, list or tuple
            Sorted breakpoints. A list or tuple of exactly 3 elements is
            interpreted as ``(lb, ub, n)`` and expanded with ``np.linspace``.
            NOTE: this means three explicit breakpoints must be passed as an
            ndarray.
        evennum : int
            If > 0, ``breaks`` must hold exactly 2 values ``(lb, ub)``, which
            are expanded to ``evennum`` evenly spaced breakpoints. If 0 and
            there are 2 breakpoints, it is set to 2.
        k : int
            Spline order; must be positive.

        Raises
        ------
        ValueError
            For a non-positive ``k``, unsorted or too few breakpoints, an
            ``evennum`` inconsistent with ``breaks``, or an unusable input
            type.
        """
        if k <= 0:
            raise ValueError("spline order must be positive")
        if isinstance(breaks, (list, tuple)):
            if len(breaks) == 3:
                # assume it is of the form (lb, ub, n)
                breaks = np.linspace(*breaks)
            else:
                # assume the breaks are given directly.
                breaks = np.asarray(breaks)
        if not isinstance(breaks, np.ndarray):
            raise ValueError("Couldn't interpret input form of breaks")
        if not np.all(breaks == np.sort(breaks)):
            raise ValueError("Breaks must be sorted")
        if breaks.size < 2:
            raise ValueError("must have at least 2 breakpoints")
        if evennum == 0:
            if breaks.size == 2:
                evennum = 2
        elif breaks.size == 2:
            # Expand (lb, ub) into ``evennum`` evenly spaced breakpoints.
            # (breaks has exactly two entries here, so the upper bound is
            # breaks[1].)
            breaks = np.linspace(breaks[0], breaks[1], evennum)
        else:
            m = "Breakpoint sequence must contain 2 values when"
            raise ValueError(m + " evennum > 0")
        self.breaks = breaks
        self.evennum = evennum
        self.k = k
        self.n = len(breaks) + k - 1

    @property
    def nodes(self):
        """The ``n`` nodes: averages of ``k`` consecutive augmented breakpoints.

        The breakpoints are padded with ``k`` copies of each endpoint. The
        first and last nodes are pinned to the endpoints.
        """
        breaks, k = self.breaks, self.k
        a = breaks[0]
        b = breaks[-1]
        n = self.n
        x = np.cumsum(np.concatenate((np.full(k, a), breaks, np.full(k, b))))
        # cumsum difference over a window of k entries, divided by k
        x = (x[k:n + k] - x[:n]) / k
        x[0] = a
        x[-1] = b
        return x

    def derivative_op(self, order=1):
        """Derivative/integral operator; not implemented yet."""
        raise NotImplementedError()

    def basis_matrix(self, x, order=0):
        """Evaluate the spline basis (or derivatives) at the points ``x``.

        Parameters
        ----------
        x : numpy.ndarray
            Evaluation points; a 1-d array or an array whose trailing
            dimensions are all 1 (it is flattened).
        order : int or array of int
            Derivative order(s); each must be ``< k``. Nonzero orders call
            ``derivative_op``, which currently raises NotImplementedError.

        Returns
        -------
        list of scipy.sparse matrices
            One CSR matrix per distinct requested order, from the highest
            order to the lowest.
        """
        breaks, k = self.breaks, self.k
        order = np.atleast_1d(order)
        if np.any(order >= k):
            raise ValueError("Order of diff must be < k")
        if x.ndim > 1:
            if np.any(np.asarray(x.shape[1:]) > 1):
                raise ValueError("x must be a vector")
            x = x.reshape(-1)
        m = len(x)
        minorder = order.min()

        # Augment the breakpoint sequence with repeated endpoints.
        n = self.n
        a = breaks[0]
        b = breaks[-1]
        pad = k - minorder
        augbreaks = np.concatenate((np.full(pad, a), breaks, np.full(pad, b)))
        ind = lookup(augbreaks, x, 3)

        bas = np.zeros((k - minorder + 1, m))
        bas[0, :] = 1.0
        B = []
        if order.max() > 0:
            D = self.derivative_op(order.max())
        if minorder < 0:
            Iop = self.derivative_op(minorder)
        for j in range(1, k - minorder + 1):
            # Raise the basis order by one using the recursion on bas rows.
            for jj in range(j, 0, -1):
                b0 = augbreaks[ind + jj - j]
                b1 = augbreaks[ind + jj]
                temp = bas[jj - 1, :] / (b1 - b0)
                bas[jj, :] = (x - b0) * temp + bas[jj, :]
                bas[jj - 1, :] = (b1 - x) * temp
            # bas now contains the `j` order spline basis
            ii = np.argwhere(order == k - j)
            if len(ii) == 0:
                continue
            ii = ii[0][0]
            oi = order[ii]
            # Each evaluation point has k - oi + 1 nonzeros for this order;
            # the row pointer must use that stride (not k + 1).
            nper = k - oi + 1
            row_ind_ptr = np.arange(0, m * nper + 1, nper, dtype="int32")
            colval = np.arange(-k + oi, 1, dtype="int32") - (oi - minorder)
            colval = (colval[None, :] + ind[:, None]).ravel(order="C")
            val = np.ravel(bas[:nper, :], order="F")
            B_ii = spa.csr_matrix((val, colval, row_ind_ptr),
                                  shape=(m, n - oi))
            if oi > 0:
                B.append(B_ii.dot(D[oi]))
            elif oi < 0:
                B.append(B_ii.dot(Iop[oi]))
            else:
                B.append(B_ii)
        return B
class Lin(LinearBasis):
    """Piecewise-linear basis on a sorted breakpoint sequence."""

    def __init__(self, breaks, evennum=0):
        """Build the basis.

        Parameters
        ----------
        breaks : array-like or tuple
            Sorted breakpoints, or a 3-tuple ``(lb, ub, n)`` expanded with
            ``np.linspace``.
        evennum : int
            Stored; only ``0`` is supported by ``basis_matrix``.

        Raises
        ------
        ValueError
            If the breakpoints are not a sorted 1-d sequence.
        """
        # could be of form (lb, ub, n)
        if isinstance(breaks, tuple) and len(breaks) == 3:
            breaks = np.linspace(*breaks)
        # Lists and other sequences are accepted too (previously they left
        # self.breaks unset).
        breaks = np.asarray(breaks)
        if breaks.ndim != 1:
            raise ValueError("breaks must be a 1-d sequence")
        if not np.all(breaks == np.sort(breaks)):
            raise ValueError("breaks must be sorted")
        self.breaks = breaks
        self.evennum = evennum

    @property
    def nodes(self):
        """The breakpoints themselves."""
        return self.breaks

    def basis_matrix(self, x, orders=0):
        """Evaluate the linear basis at the 1-d points ``x``.

        Row ``i`` has weight ``1 - z`` at column ``ind`` and ``z`` at
        ``ind + 1``. Here ``ind`` is the bracketing interval from
        ``lookup(breaks, x, 3)`` and ``z`` is the relative position of
        ``x[i]`` inside it.

        Returns
        -------
        list
            A single ``(len(x), len(breaks))`` CSR matrix.

        Raises
        ------
        NotImplementedError
            For any order other than the integer 0, or ``evennum != 0``.
        """
        m = x.shape[0]
        breaks = self.breaks
        n = breaks.size
        # Only level (order 0) evaluation exists; reject other integers
        # instead of silently returning the level basis.
        if not isinstance(orders, (int, np.integer)) or orders != 0:
            raise NotImplementedError("Only order=0 for Lin is implemented")
        if self.evennum != 0:
            raise NotImplementedError("I haven't done this bit yet")
        ind = lookup(breaks, x, 3)
        lo = breaks[ind]
        z = (x - lo) / (breaks[ind + 1] - lo)

        row_ind_ptr = np.arange(0, 2 * m + 1, 2, dtype="int32")
        colval = np.empty((2 * m,), dtype="int32")
        colval[::2] = ind
        colval[1::2] = ind + 1
        # Integer x would truncate the fractional weights, so use float64
        # unless x is already a floating/complex type.
        if np.issubdtype(x.dtype, np.inexact):
            val_dtype = x.dtype
        else:
            val_dtype = np.float64
        val = np.empty((2 * m,), dtype=val_dtype)
        val[::2] = 1 - z
        val[1::2] = z
        return [spa.csr_matrix((val, colval, row_ind_ptr), shape=(m, n))]
class Basis(object):
    """Multi-dimensional basis built from one ``LinearBasis`` per dimension."""

    def __init__(self, *params):
        """Store the per-dimension bases.

        Raises
        ------
        ValueError
            If any argument is not a ``LinearBasis`` instance.
        """
        for p in params:
            if not isinstance(p, LinearBasis):
                msg = "Can only construct a basis using instances of "
                raise ValueError(msg + "LinearBasis")
        self.params = params
        self._nodes = None
        self.n = len(params)

    def _b_mat_tensor(self):
        # Unused placeholder; tensor evaluation is done by _basis_mat_tensor.
        pass

    def _basis_mat_direct(self, x, orders):
        """Evaluate dimension ``col`` at ``x[:, col]`` for each order it needs.

        Returns one dict per dimension mapping each order in
        ``orders[:, col]`` to that dimension's basis matrix.
        """
        return [
            {o: self.params[col](x[:, col], orders=o)
             for o in np.unique(orders[:, col])}
            for col in range(orders.shape[1])
        ]

    def _basis_mat_tensor(self, x, orders):
        """Evaluate dimension ``col`` at its own 1-d point set ``x[col]``.

        Returns one dict per dimension mapping each order in
        ``orders[:, col]`` to that dimension's basis matrix.
        """
        # Key each matrix by its own order (previously every order was
        # stored under key 0, overwriting the others).
        return [
            {o: self.params[col](x[col], orders=o)
             for o in np.unique(orders[:, col])}
            for col in range(orders.shape[1])
        ]

    def basis_matrix(self, x, orders=0, format="direct"):
        """Evaluate the basis.

        Parameters
        ----------
        x : numpy.ndarray or sequence of arrays
            ``"direct"``/``"expanded"``: an ``(m, n)`` array (or a 1-d array
            of length ``n`` for a single point). ``"tensor"``: ``n`` 1-d
            arrays, one per dimension.
        orders : int, 1-d or 2-d array of int
            Derivative orders, see ``_check_order``.
        format : {"direct", "expanded", "tensor"}

        Returns
        -------
        BasisMatrix

        Raises
        ------
        ValueError
            For malformed ``x``, ``orders`` or ``format``.
        """
        orders = self._check_order(orders)
        if format == "direct":
            if not isinstance(x, np.ndarray):
                msg = "x must be a numpy array for evaluation in direct form"
                raise ValueError(msg)
            if x.ndim == 1:
                if x.size != self.n:
                    msg = f"A 1d array was passed, must be shape (1, {self.n})"
                    raise ValueError(msg)
                x = x.reshape((1, self.n))
            if x.ndim != 2:
                raise ValueError(
                    f"x must be 1d or 2d, found {x.ndim} dimensions")
            if x.shape[1] != self.n:
                msg = f"A 2d array was passed, must have {self.n} columns"
                raise ValueError(msg)
            Bs = self._basis_mat_direct(x, orders)
            return BasisMatrix(Bs, orders, "direct")
        if format == "expanded":
            return self.basis_matrix(x, orders, format="direct").to_expanded()
        if format == "tensor":
            if len(x) != self.n:
                msg = f"For tensor format and {self.n} dim basis, must pass "
                raise ValueError(msg + f"{self.n} arrays")
            Bs = self._basis_mat_tensor(x, orders)
            return BasisMatrix(Bs, orders, "tensor")
        msg = "Format must be one of ['tensor', 'direct', 'expanded']. "
        raise ValueError(msg + f"Found {format}")

    @property
    def nodes(self):
        """``(grid, dim_nodes)``, computed once and then cached.

        ``dim_nodes`` lists each dimension's nodes and ``grid`` is their
        ``qe.cartesian`` product.
        """
        if self._nodes is None:
            dim_nodes = [p.nodes for p in self.params]
            grid = qe.cartesian(dim_nodes)
            self._nodes = (grid, dim_nodes)
        return self._nodes

    def _check_order(self, order):
        """Normalize ``order`` to a 2-d array with ``self.n`` columns.

        A scalar applies to every dimension. A 1-d array must have ``n``
        entries. A 2-d array must have ``n`` columns, one row per
        requested combination.

        Raises
        ------
        ValueError
            If ``order`` cannot be interpreted as above.
        """
        order = np.asarray(order)
        if order.ndim == 0:  # scalar case
            return np.full(shape=(1, self.n), fill_value=order)
        if order.ndim == 1:  # vector case
            if order.size != self.n:
                msg = f"Expected {self.n} orders, found {order.size}"
                raise ValueError(msg)
            return order.reshape((1, self.n))
        if order.ndim == 2:
            if order.shape[1] != self.n:
                msg = f"Order must have {self.n} columns, found {order.shape[1]}"
                raise ValueError(msg)
            return order
        msg = f"order of dimension {order.ndim}, can only handle ndim <= 2"
        raise ValueError(msg)
class BasisMatrix(object):
    """Container for evaluated basis matrices.

    ``format`` is one of:

    * ``"direct"`` / ``"tensor"``: ``Bs`` is a list with one dict per
      dimension, mapping derivative order to that dimension's matrix.
    * ``"expanded"``: ``Bs`` is a list with one combined matrix per row of
      ``orders``.
    """

    def __init__(self, Bs, orders, format):
        """Validate and store the matrices.

        Raises
        ------
        ValueError
            If ``orders`` is not a 2-d ndarray, or if direct/tensor ``Bs`` is
            not a list/tuple/array of dicts.
        """
        if not isinstance(orders, np.ndarray):
            msg = "orders must be a numpy array"
            raise ValueError(msg)
        if orders.ndim != 2:
            msg = "orders must have 2 dimensions"
            raise ValueError(msg)
        if format in ["direct", "tensor"]:
            if (not isinstance(Bs, (tuple, list, np.ndarray))
                    or not all(isinstance(d, dict) for d in Bs)):
                msg = "When using direct or tensor form, pass Bs as a list"
                raise ValueError(msg + " of dict")
        # Build a 1-d object array explicitly. np.asarray would try to stack
        # same-shaped dense matrices and fails on ragged ones.
        arr = np.empty(len(Bs), dtype=object)
        for i, item in enumerate(Bs):
            arr[i] = item
        self.Bs = arr
        self.orders = orders
        self.format = format

    def to_direct(self):
        """Return an equivalent direct-form ``BasisMatrix``.

        A tensor-form matrix is converted by repeating each dimension's rows
        over the ``qe.cartesian`` grid of point indices.

        Raises
        ------
        ValueError
            When called on an expanded-form matrix.
        """
        if self.format == "direct":
            return self
        if self.format == "tensor":
            # Row count of each dimension, read from any one of its matrices.
            raw_inds = [
                np.arange(next(iter(Bd.values())).shape[0], dtype="int32")
                for Bd in self.Bs
            ]
            inds = qe.cartesian(raw_inds)
            out = [
                {key: val[inds[:, i], :] for key, val in Bd.items()}
                for i, Bd in enumerate(self.Bs)
            ]
            return BasisMatrix(out, self.orders, "direct")
        msg = "Cannot convert from expanded to direct form"
        raise ValueError(msg)

    def to_expanded(self):
        """Return an expanded-form ``BasisMatrix``, one matrix per ``orders`` row.

        Direct form combines dimensions with ``row_kron``; tensor form with
        ``full_kron``.
        """
        if self.format == "expanded":
            return self
        if self.format not in ("direct", "tensor"):
            msg = f"Unknown format {self.format}... How'd you get this?"
            raise ValueError(msg)
        func = row_kron if self.format == "direct" else full_kron
        out = [
            reduce(func, [self.Bs[i][o] for i, o in enumerate(ords)])
            for ords in self.orders
        ]
        return BasisMatrix(out, self.orders, "expanded")

    @staticmethod
    def _solve(B, rhs):
        """Solve ``B @ X = rhs``.

        Square sparse ``B`` uses ``spsolve``; anything else uses least
        squares (``scipy.linalg.lstsq``).
        """
        if spa.issparse(B):
            if B.shape[0] == B.shape[1]:
                return spla.spsolve(B, rhs)
            B = B.toarray()
        # lstsq returns (solution, residues, rank, singular values).
        return la.lstsq(B, rhs)[0]

    def filter(self, y):
        """Solve for coefficients ``c`` with ``Phi @ c = y``.

        ``Phi`` is the zero-order basis matrix.

        Parameters
        ----------
        y : numpy.ndarray
            Values at the evaluation points, 1-d or 2-d (one column per
            series).

        Returns
        -------
        numpy.ndarray
            Coefficients with the same trailing shape as ``y``.

        Raises
        ------
        ValueError
            If no all-zero order row exists, if ``y`` does not conform
            (tensor form), or if the format is unknown.
        """
        ii = np.where(np.all(self.orders == 0, axis=1))[0]
        if len(ii) == 0:
            msg = "Zero order basis not found in BasisMatrix. can't filter"
            raise ValueError(msg)
        if self.format == "direct":
            return self.to_expanded().filter(y)
        if self.format == "expanded":
            return self._solve(self.Bs[ii[0]], y)
        if self.format == "tensor":
            b = [Bd[0] for Bd in self.Bs]
            rows = [bi.shape[0] for bi in b]
            y = np.asarray(y)
            if np.prod(rows) != y.shape[0]:
                raise ValueError("BasisMatrix and y are unconformable")
            tail = y.shape[1:]
            # View y as a tensor with one axis per dimension. C order with
            # the last dimension fastest matches full_kron / np.kron. Then
            # apply each factor's inverse along its own axis.
            z = y.reshape(tuple(rows) + tail)
            for axis, bi in enumerate(b):
                z = np.moveaxis(z, axis, 0)
                rest = z.shape[1:]
                sol = self._solve(bi, z.reshape((z.shape[0], -1)))
                z = np.asarray(sol).reshape((bi.shape[1],) + rest)
                z = np.moveaxis(z, 0, axis)
            return z.reshape((-1,) + tail)
        raise ValueError(f"Unknown format {self.format}")
# ----- # | |
# Tests # | |
# ----- # | |
def test_row_kron():
    """Check row_kron on every dense/sparse input combination.

    First against a hard-coded expected result, then, for random sparse
    inputs, the dense-dense result against ``np.kron`` row by row and the
    other three combinations against the dense-dense result.
    """
    A = np.array([
        [
            0.8264199353134758, 0.3348224988763786, 0.683690418471337,
            0.5021681483282827
        ],
        [
            0.8083856161115917, 0.8409840370611767, 0.819167773060582,
            0.6669167460999958
        ],
        [
            0.30063303241442885, 0.6608093259051855, 0.3891429787229923,
            0.018140909322307497
        ],
        [
            0.9497853372445253, 0.11998096796907887, 0.7869829064282052,
            0.3769530502240095
        ],
        [
            0.10066745535712318, 0.7476934606909804, 0.5850579436620573,
            0.35147541422531225
        ],
    ])
    B = np.array([
        [0.5583446948325421, 0.8480075982050823, 0.5531501558443193],
        [0.0102169807402257, 0.12707447035237185, 0.4130074757062898],
        [0.7226883352203841, 0.011663092121530827, 0.6397819036107639],
        [0.8642752828232878, 0.4379128600951878, 0.23778812232529578],
        [0.7871169627111041, 0.175590791212894, 0.4047727247721271],
    ])
    # Expected row-wise products: row i is kron(A[i], B[i]), 4 * 3 columns.
    want = np.array([
        [
            0.46142718658613185, 0.7008103844539801, 0.4571343160115014,
            0.1869463659582008, 0.2839320230971817, 0.18520711743365323,
            0.3817349180613117, 0.5797746696837062, 0.3781834615266879,
            0.2803829215329778, 0.42584240535896056, 0.27777438950784283
        ],
        [
            0.00825926027048762, 0.1027251740078563, 0.3338693027075223,
            0.0085923177094913, 0.10686760108434849, 0.3473326942559215,
            0.008369421360373544, 0.10409531089140539, 0.33832241413169384,
            0.006813875550237651, 0.08474809227978422, 0.2754416018130119
        ],
        [
            0.21726398570783936, 0.0035063107518246467, 0.19233957376637978,
            0.4775591916365227, 0.007707080042798866, 0.4227738484513653,
            0.28122909145602065, 0.00453861040929317, 0.24896663570415903,
            0.013110223557522352, 0.0002115790965944097, 0.011606225499456144
        ],
        [
            0.8208759909684239, 0.41592321350922257, 0.22584767195547348,
            0.10369658502488747, 0.052541208840328445, 0.02853004908813872,
            0.6801698740303301, 0.3446299353999989, 0.18713518762166687,
            0.3257912040934568, 0.16507258834520097, 0.08963495801756013
        ],
        [
            0.07923706170455447, 0.017676278135545946, 0.040747440200779216,
            0.5885222058180388, 0.1312880863474361, 0.30264591937818947,
            0.46050903162528284, 0.10273078723300941, 0.2368154980056686,
            0.27665226051265496, 0.06171584607570224, 0.1422676611063917
        ],
    ])
    assert np.allclose(row_kron(A, B), want)
    assert np.allclose(row_kron(spa.csr_matrix(A), B).toarray(), want)
    assert np.allclose(
        row_kron(spa.csr_matrix(A), spa.csr_matrix(B)).toarray(), want)
    assert np.allclose(row_kron(A, spa.csr_matrix(B)).toarray(), want)
    # Random sparse inputs (unseeded) with explicit zeros in both operands.
    for i in range(5):
        As = spa.random(10, 5, 0.4)
        Bs = spa.random(10, 3, 0.4)
        A = As.toarray()
        B = Bs.toarray()
        # make sure dense-dense version is correct. Then we can compare
        # the other three versions against the dense-dense one
        have = row_kron(A, B)
        for row in range(As.shape[0]):
            assert np.allclose(have[row, :], np.kron(A[row, :], B[row, :]))
        assert np.allclose(row_kron(A, Bs).toarray(), have)
        assert np.allclose(row_kron(As, B).toarray(), have)
        assert np.allclose(row_kron(As, Bs).toarray(), have)
def test_full_kron():
    """Compare full_kron with ``np.kron`` for every dense/sparse combination.

    The random inputs are seeded so any failure is reproducible.
    """
    for seed in range(5):
        As = spa.random(10, 5, 0.5, random_state=seed)
        Bs = spa.random(10, 10, 0.5, random_state=seed + 100)
        A = As.toarray()
        B = Bs.toarray()
        want = np.kron(A, B)
        assert np.allclose(full_kron(A, B), want)
        assert np.allclose(full_kron(As, B).toarray(), want)
        assert np.allclose(full_kron(A, Bs).toarray(), want)
        assert np.allclose(full_kron(As, Bs).toarray(), want)
def prof_row_kron_csr_csr():
    """Profiling driver: call row_kron on two random 500x10 CSR matrices 5000 times."""
    n_repeats = 5000
    lhs, rhs = (spa.random(500, 10, 0.2).tocsr() for _ in range(2))
    for _ in range(n_repeats):
        row_kron(lhs, rhs)
def test_lookup():
    """Check lookup's endpoint handling for p = 0..3 on three tables."""
    queries = np.array([0.5, 1.0, 1.5, 4.0, 5.5])
    short_queries = np.array([0.5, 2.0])
    two_point = np.array([1.0, 4.0])
    repeated = np.array(
        [1.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0])
    # (table, x, expected result for p = 0, 1, 2, 3)
    cases = [
        (two_point, queries, [
            [-1, 0, 0, 1, 1],
            [0, 0, 0, 1, 1],
            [-1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ]),
        (repeated, queries, [
            [-1, 2, 2, 11, 11],
            [2, 2, 2, 11, 11],
            [-1, 2, 2, 7, 7],
            [2, 2, 2, 7, 7],
        ]),
        (np.array([1.0]), short_queries, [
            [-1, 0],
            [0, 0],
            [-1, 0],
            [0, 0],
        ]),
    ]
    for table, xs, expected_by_p in cases:
        for p, expected in enumerate(expected_by_p):
            assert all(lookup(table, xs, p) == np.array(expected))
if __name__ == '__main__':
    # Demo: 2-d basis (spline with k=1 x piecewise-linear), evaluated in
    # tensor form at its own nodes, then fit to random data.
    spline_dim = Spline(np.linspace(0, 1, 10), 0, 1)
    lin_dim = Lin(np.linspace(0, 1, 10), 0)
    basis = Basis(spline_dim, lin_dim)
    grid, dim_nodes = basis.nodes
    bm = basis.basis_matrix(dim_nodes, format="tensor")
    y = np.random.randn(grid.shape[0])
    bm.filter(y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment