Skip to content

Instantly share code, notes, and snippets.

@sglyon
Last active December 29, 2023 17:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save sglyon/7790b35049342be7cc688549ad69285c to your computer and use it in GitHub Desktop.
Save sglyon/7790b35049342be7cc688549ad69285c to your computer and use it in GitHub Desktop.
basis matrices in python
"""
This is basis matrices in python
"""
import numpy as np
from numba import njit
import quantecon as qe
import scipy.sparse as spa
import scipy.sparse.linalg as spla
import scipy.linalg as la
from functools import reduce
@njit(cache=True)
def row_kron_dense_csr(m, na, nb, A, Bval, Bind, Bptr, outval, outind, outptr):
    """
    Row-wise Kronecker product of a dense `A` with a CSR `B`.

    Row `r` of the result holds `A[r, ja] * B[r, jb]` in column
    `ja * nb + jb`. The result is written in CSR layout into the
    caller-supplied buffers `outval` (values), `outind` (column indices) and
    `outptr` (row pointers, `m + 1` entries), which are also returned.

    `m` is the number of rows, `na` / `nb` the number of columns of `A` /
    `B`, and `Bval`, `Bind`, `Bptr` are the data, indices and indptr arrays
    of `B`. Nothing is bounds-checked here: the buffers must be large enough
    for every (column of A, stored entry of B) pair of every row.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        b_start = Bptr[r]
        b_stop = Bptr[r + 1]
        for ja in range(na):
            a_val = A[r, ja]
            offset = ja * nb
            for kb in range(b_start, b_stop):
                outval[pos] = a_val * Bval[kb]
                outind[pos] = offset + Bind[kb]
                pos += 1
        # everything written so far belongs to rows 0..r
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def row_kron_dense_dense(m, na, nb, A, B, out):
    """
    Row-wise Kronecker product of two dense matrices.

    Fills `out[r, ja * nb + jb] = A[r, ja] * B[r, jb]` for every row `r` in
    `range(m)`, every `ja` in `range(na)` and every `jb` in `range(nb)`, and
    returns `out`.
    """
    for r in range(m):
        for ja in range(na):
            a_val = A[r, ja]
            base = ja * nb
            for jb in range(nb):
                out[r, base + jb] = a_val * B[r, jb]
    return out
@njit(cache=True)
def row_kron_csr_dense(m, na, nb, Aval, Aind, Aptr, B, outval, outind, outptr):
    """
    Row-wise Kronecker product of a CSR `A` with a dense `B`.

    For every stored entry of `A` in row `r` (value `a`, column `ja`) and
    every column `jb` of `B`, the value `a * B[r, jb]` is written at column
    `ja * nb + jb`. Output goes, in CSR layout, into the caller-supplied
    buffers `outval`, `outind` and `outptr` (`m + 1` entries), which are
    also returned.

    `Aval`, `Aind`, `Aptr` are the data, indices and indptr arrays of `A`.
    `na` is accepted for a uniform signature but is not used.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        for ka in range(Aptr[r], Aptr[r + 1]):
            a_val = Aval[ka]
            offset = Aind[ka] * nb
            for jb in range(nb):
                outval[pos] = a_val * B[r, jb]
                outind[pos] = offset + jb
                pos += 1
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def row_kron_csr_csr(m, na, nb, Aval, Aind, Aptr, Bval, Bind, Bptr, outval,
                     outind, outptr):
    """
    Row-wise Kronecker product of two CSR matrices.

    For every row `r`, each stored entry of `A` (value `a`, column `ja`) is
    paired with each stored entry of `B` (value `b`, column `jb`) and the
    product `a * b` is written at column `ja * nb + jb`. Output goes, in CSR
    layout, into the caller-supplied buffers `outval`, `outind` and `outptr`
    (`m + 1` entries), which are also returned.

    `na` is accepted for a uniform signature but is not used.
    """
    outptr[0] = 0
    pos = 0
    for r in range(m):
        b_start = Bptr[r]
        b_stop = Bptr[r + 1]
        for ka in range(Aptr[r], Aptr[r + 1]):
            a_val = Aval[ka]
            offset = Aind[ka] * nb
            for kb in range(b_start, b_stop):
                outval[pos] = a_val * Bval[kb]
                outind[pos] = offset + Bind[kb]
                pos += 1
        outptr[r + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def nnz_row_one_csr(nrow, m_dense, csr_ptr):
    """
    Count the entries a row-wise Kronecker product of one dense and one CSR
    matrix will store.

    Each of the first `nrow` rows contributes `m_dense` (number of columns
    of the dense matrix) times the number of stored entries of that row in
    the CSR matrix, as given by its row-pointer array `csr_ptr`.
    """
    total = 0
    for r in range(nrow):
        stored_in_row = csr_ptr[r + 1] - csr_ptr[r]
        total += m_dense * stored_in_row
    return total
@njit(cache=True)
def nnz_row_csr_csr(m, Aptr, Bptr):
    """
    Count the entries a row-wise Kronecker product of two CSR matrices will
    store.

    For each of the first `m` rows, the count is the number of stored
    entries of `A` in that row times the number of stored entries of `B` in
    that row, read from the row-pointer arrays `Aptr` and `Bptr`.
    """
    total = 0
    for r in range(m):
        in_a = Aptr[r + 1] - Aptr[r]
        in_b = Bptr[r + 1] - Bptr[r]
        total += in_a * in_b
    return total
def row_kron(A, B):
    """
    Row-wise Kronecker product of `A` and `B`.

    Row `r` of the result equals `np.kron(A[r, :], B[r, :])`, so the result
    has shape `(A.shape[0], A.shape[1] * B.shape[1])`.

    Parameters
    ----------
    A, B : numpy.ndarray or scipy sparse matrix
        Two-dimensional inputs with the same number of rows. Sparse inputs
        that are not already CSR are converted with `tocsr()`.

    Returns
    -------
    numpy.ndarray or scipy.sparse.csr_matrix
        A dense array when both inputs are dense, otherwise a CSR matrix.
        The value dtype is `np.result_type(A.dtype, B.dtype)` so that mixed
        integer / float inputs are not truncated.

    Raises
    ------
    ValueError
        If `A` and `B` have a different number of rows. This is raised
        explicitly (rather than asserted) because the compiled kernels do
        no bounds checking of their own.
    """
    ma = A.shape[0]
    mb = B.shape[0]
    na = A.shape[1]
    nb = B.shape[1]
    if ma != mb:
        raise ValueError(
            f"A and B must have the same number of rows, got {ma} and {mb}")

    out_dtype = np.result_type(A.dtype, B.dtype)
    a_sparse = spa.issparse(A)
    b_sparse = spa.issparse(B)

    if not a_sparse and not b_sparse:
        out = np.empty((ma, na * nb), dtype=out_dtype)
        return row_kron_dense_dense(ma, na, nb, A, B, out)

    # At least one input is sparse: fill CSR buffers with the matching kernel
    outptr = np.zeros(ma + 1, dtype="int32")
    if a_sparse:
        As = A if spa.isspmatrix_csr(A) else A.tocsr()
        if b_sparse:
            Bs = B if spa.isspmatrix_csr(B) else B.tocsr()
            nnz = nnz_row_csr_csr(ma, As.indptr, Bs.indptr)
            outval = np.zeros(nnz, dtype=out_dtype)
            outind = np.zeros(nnz, dtype="int32")
            row_kron_csr_csr(ma, na, nb, As.data, As.indices, As.indptr,
                             Bs.data, Bs.indices, Bs.indptr,
                             outval, outind, outptr)
        else:
            nnz = nnz_row_one_csr(ma, nb, As.indptr)
            outval = np.zeros(nnz, dtype=out_dtype)
            outind = np.zeros(nnz, dtype="int32")
            row_kron_csr_dense(ma, na, nb, As.data, As.indices, As.indptr,
                               B, outval, outind, outptr)
    else:
        Bs = B if spa.isspmatrix_csr(B) else B.tocsr()
        nnz = nnz_row_one_csr(ma, na, Bs.indptr)
        outval = np.zeros(nnz, dtype=out_dtype)
        outind = np.zeros(nnz, dtype="int32")
        row_kron_dense_csr(ma, na, nb, A, Bs.data, Bs.indices, Bs.indptr,
                           outval, outind, outptr)

    return spa.csr_matrix(
        (outval, outind, outptr), shape=(ma, na * nb), copy=False)
@njit(cache=True)
def full_kron_dense_csr(ma, mb, na, nb, A, Bval, Bind, Bptr, outval, outind, outptr):
    """
    Kronecker product of a dense `A` (`ma` x `na`) with a CSR `B`
    (`mb` x `nb`).

    Output row `ia * mb + ib` receives, for every column `ja` of `A` and
    every stored entry of row `ib` of `B` (value `b`, column `jb`), the
    value `A[ia, ja] * b` at column `ja * nb + jb`. The result is written in
    CSR layout into the caller-supplied `outval`, `outind` and `outptr`
    (`ma * mb + 1` entries), which are also returned.
    """
    outptr[0] = 0
    pos = 0
    for ia in range(ma):
        for ib in range(mb):
            out_row = ia * mb + ib
            b_start = Bptr[ib]
            b_stop = Bptr[ib + 1]
            for ja in range(na):
                a_val = A[ia, ja]
                base = ja * nb
                for kb in range(b_start, b_stop):
                    outval[pos] = a_val * Bval[kb]
                    outind[pos] = base + Bind[kb]
                    pos += 1
            outptr[out_row + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def full_kron_dense_dense(ma, mb, na, nb, A, B, out):
    """
    Kronecker product of two dense matrices.

    Fills `out[ia * mb + ib, ja * nb + jb] = A[ia, ja] * B[ib, jb]` for all
    index combinations and returns `out`.
    """
    for ia in range(ma):
        for ib in range(mb):
            out_row = ia * mb + ib
            for ja in range(na):
                a_val = A[ia, ja]
                base = ja * nb
                for jb in range(nb):
                    out[out_row, base + jb] = a_val * B[ib, jb]
    return out
@njit(cache=True)
def full_kron_csr_dense(ma, mb, na, nb, Aval, Aind, Aptr, B, outval, outind, outptr):
    """
    Kronecker product of a CSR `A` (`ma` x `na`) with a dense `B`
    (`mb` x `nb`).

    Output row `ia * mb + ib` receives, for every stored entry of row `ia`
    of `A` (value `a`, column `ja`) and every column `jb` of `B`, the value
    `a * B[ib, jb]` at column `ja * nb + jb`. The result is written in CSR
    layout into the caller-supplied `outval`, `outind` and `outptr`
    (`ma * mb + 1` entries), which are also returned.

    `na` is accepted for a uniform signature but is not used.
    """
    outptr[0] = 0
    pos = 0
    for ia in range(ma):
        a_start = Aptr[ia]
        a_stop = Aptr[ia + 1]
        for ib in range(mb):
            out_row = ia * mb + ib
            for ka in range(a_start, a_stop):
                a_val = Aval[ka]
                base = Aind[ka] * nb
                for jb in range(nb):
                    outval[pos] = a_val * B[ib, jb]
                    outind[pos] = base + jb
                    pos += 1
            outptr[out_row + 1] = pos
    return outval, outind, outptr
@njit(cache=True)
def full_kron_csr_csr(ma, mb, na, nb, Aval, Aind, Aptr, Bval, Bind, Bptr, outval,
                      outind, outptr):
    """
    Kronecker product of two CSR matrices.

    Output row `ia * mb + ib` receives, for every stored entry of row `ia`
    of `A` (value `a`, column `ja`) and every stored entry of row `ib` of
    `B` (value `b`, column `jb`), the value `a * b` at column
    `ja * nb + jb`. The result is written in CSR layout into the
    caller-supplied `outval`, `outind` and `outptr` (`ma * mb + 1` entries),
    which are also returned.

    `na` is accepted for a uniform signature but is not used.
    """
    outptr[0] = 0
    pos = 0
    for ia in range(ma):
        a_start = Aptr[ia]
        a_stop = Aptr[ia + 1]
        for ib in range(mb):
            out_row = ia * mb + ib
            b_start = Bptr[ib]
            b_stop = Bptr[ib + 1]
            for ka in range(a_start, a_stop):
                a_val = Aval[ka]
                base = Aind[ka] * nb
                for kb in range(b_start, b_stop):
                    outval[pos] = a_val * Bval[kb]
                    outind[pos] = base + Bind[kb]
                    pos += 1
            outptr[out_row + 1] = pos
    return outval, outind, outptr
def full_kron(A, B):
    """
    Kronecker product of `A` and `B`, with the same layout as `np.kron`.

    Parameters
    ----------
    A, B : numpy.ndarray or scipy sparse matrix
        Two-dimensional inputs. Sparse inputs that are not already CSR are
        converted with `tocsr()`.

    Returns
    -------
    numpy.ndarray or scipy.sparse.csr_matrix
        Shape `(A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])`. Dense
        when both inputs are dense, otherwise CSR. The value dtype is
        `np.result_type(A.dtype, B.dtype)` on every path, so mixed integer /
        float inputs are not truncated to the dtype of `A`.
    """
    ma = A.shape[0]
    mb = B.shape[0]
    na = A.shape[1]
    nb = B.shape[1]

    out_dtype = np.result_type(A.dtype, B.dtype)
    a_sparse = spa.issparse(A)
    b_sparse = spa.issparse(B)

    if not a_sparse and not b_sparse:
        out = np.empty((ma * mb, na * nb), dtype=out_dtype)
        return full_kron_dense_dense(ma, mb, na, nb, A, B, out)

    # At least one input is sparse: fill CSR buffers with the matching kernel
    outptr = np.zeros(ma * mb + 1, dtype="int32")
    if a_sparse:
        As = A if spa.isspmatrix_csr(A) else A.tocsr()
        if b_sparse:
            Bs = B if spa.isspmatrix_csr(B) else B.tocsr()
            # every stored entry of A meets every stored entry of B
            nnz = As.data.size * Bs.data.size
            outval = np.zeros(nnz, dtype=out_dtype)
            outind = np.zeros(nnz, dtype="int32")
            full_kron_csr_csr(ma, mb, na, nb, As.data, As.indices, As.indptr,
                              Bs.data, Bs.indices, Bs.indptr,
                              outval, outind, outptr)
        else:
            # every stored entry of A meets every element of dense B
            nnz = As.data.size * B.size
            outval = np.zeros(nnz, dtype=out_dtype)
            outind = np.zeros(nnz, dtype="int32")
            full_kron_csr_dense(ma, mb, na, nb, As.data, As.indices,
                                As.indptr, B, outval, outind, outptr)
    else:
        Bs = B if spa.isspmatrix_csr(B) else B.tocsr()
        # every element of dense A meets every stored entry of B
        nnz = A.size * Bs.data.size
        outval = np.zeros(nnz, dtype=out_dtype)
        outind = np.zeros(nnz, dtype="int32")
        full_kron_dense_csr(ma, mb, na, nb, A, Bs.data, Bs.indices,
                            Bs.indptr, outval, outind, outptr)

    return spa.csr_matrix(
        (outval, outind, outptr), shape=(ma * mb, na * nb), copy=False)
@njit(cache=True)
def lookup(table, x, p):
    """
    Locate each element of `x` in `table` and return the indices as int32.

    For a `table` in ascending order, `out[i]` is the index of the last
    entry of `table` that is `<= x[i]`. The end-point behaviour depends on
    `p`:

    * `p` equal to 0 or 2: points below `table[0]` get `-1`;
    * `p` equal to 1 or 3: points below `table[0]` get the index of the last
      entry equal to `table[0]`;
    * `p >= 2`: trailing entries equal to `table[-1]` are excluded from the
      search, so points at or above the last value get the index of the
      last entry strictly below it.

    The search for `x[i]` starts from the bracket found for `x[i - 1]`,
    widening it in doubling steps and then bisecting.
    """
    n = table.size
    m = x.size
    out = np.empty((m, ), np.int32)
    clamp_below = p == 1 or p == 3

    # lower endpoint adjustment: count the leading entries equal to table[0]
    numfirst = 1
    while numfirst < n and table[numfirst] == table[0]:
        numfirst += 1

    # upper endpoint adjustment: drop the trailing entries equal to table[-1]
    if p >= 2:
        n -= 1
        for k in range(n - 1, 0, -1):
            if table[k] != table[-1]:
                break
            n -= 1

    n1 = n - 1
    n2 = n - 2

    # handle 1-value lists separately
    if n - numfirst < 1:
        for i in range(m):
            if clamp_below or table[0] <= x[i]:
                out[i] = numfirst - 1
            else:
                out[i] = -1
        return out

    lo = 0
    for i in range(m):
        step = 1
        xi = x[i]

        # hunt: widen the bracket [lo, hi] in doubling steps
        if xi >= table[lo]:
            hi = lo + 1
            while xi >= table[hi]:
                lo = hi
                hi += step
                if hi >= n:
                    break
                step += step
        else:
            hi = lo
            lo -= 1
            while xi < table[lo]:
                hi = lo
                lo -= step
                if lo < 0:
                    lo = -1
                    break
                step += step

        # bisect inside the bracket
        while hi - lo > 1:
            mid = (hi + lo) // 2
            if mid < n and xi >= table[mid]:
                lo = mid
            else:
                hi = mid

        out[i] = lo

        # keep the starting point for the next search inside [0, n - 2]
        if lo < 0:
            lo = 0
            if clamp_below:
                out[i] = numfirst - 1
        if lo == n1:
            lo = n2

    return out
class LinearBasis(object):
    """
    Base class for one-dimensional bases.

    Subclasses provide `basis_matrix(x, orders)`, which returns a list of
    basis matrices; calling an instance evaluates that method on `x`.
    """

    def __call__(self, x, orders=0):
        """
        Evaluate the basis at `x`.

        `x` is converted to an array with at least one dimension before it
        is handed to `basis_matrix`. When `orders` is a single integer
        (a Python `int` or any NumPy integer scalar) only the first matrix
        of the returned list is given back; otherwise the whole list is
        returned.
        """
        x = np.atleast_1d(np.asarray(x))
        out = self.basis_matrix(x, orders)
        # np.integer covers every NumPy integer width, not only int32/int64
        if isinstance(orders, (int, np.integer)):
            return out[0]
        return out
class Spline(LinearBasis):
    """
    B-spline basis of order `k` over a sorted sequence of breakpoints.

    Parameters
    ----------
    breaks : list, tuple or numpy.ndarray
        A list or tuple with exactly three entries is read as `(lb, ub, n)`
        and expanded with `np.linspace`; any other list or tuple, or an
        array, is used as the breakpoints themselves.
    evennum : int, optional
        When non-zero, `breaks` must hold exactly two values (the end
        points) and is replaced by `evennum` evenly spaced breakpoints
        between them.
    k : int, optional
        Order of the spline; must be positive.

    Raises
    ------
    ValueError
        If `k` is not positive, `breaks` has an unsupported type, is not
        sorted, has fewer than two entries, or `evennum` is non-zero while
        `breaks` does not have exactly two entries.
    """

    def __init__(self, breaks, evennum=0, k=3):
        if k <= 0:
            raise ValueError("spline order must be positive")

        if isinstance(breaks, (list, tuple)):
            if len(breaks) == 3:
                # assume it is of the form (lb, ub, n)
                breaks = np.linspace(*breaks)
            else:
                # assume the breaks are given directly.
                breaks = np.asarray(breaks)

        if not isinstance(breaks, np.ndarray):
            raise ValueError("Couldn't interpret input form of breaks")

        if not np.all(breaks == np.sort(breaks)):
            raise ValueError("Breaks must be sorted")
        if breaks.size < 2:
            raise ValueError("must have at least 2 breakpoints")

        if evennum == 0:
            if breaks.size == 2:
                evennum = 2
        elif breaks.size == 2:
            # the two entries are the end points: indices 0 and 1
            breaks = np.linspace(breaks[0], breaks[1], evennum)
        else:
            m = "Breakpoint sequence must contain 2 values when"
            raise ValueError(m + " evennum > 0")

        self.breaks = breaks
        self.evennum = evennum
        self.k = k
        self.n = len(breaks) + k - 1

    @property
    def nodes(self):
        """
        Return `self.n` nodes for this basis.

        Each node is the average of `k` consecutive entries of the
        breakpoint sequence padded with `k` copies of each end point; the
        first and last nodes are then set exactly to the end points.
        """
        breaks, k = self.breaks, self.k
        a = breaks[0]
        b = breaks[-1]
        n = self.n
        padded = np.concatenate((np.full(k, a), breaks, np.full(k, b)))
        running = np.cumsum(padded)
        # difference of running sums k apart = sum of k consecutive entries
        x = (running[k:n + k] - running[:n]) / k
        x[0] = a
        x[-1] = b
        return x

    def derivative_op(self, order=1):
        """Derivative / integral operator; not implemented."""
        raise NotImplementedError()

    def basis_matrix(self, x, order=0):
        """
        Evaluate the spline basis functions at the points `x`.

        Parameters
        ----------
        x : numpy.ndarray
            Evaluation points; a vector, or an array whose trailing
            dimensions all have length one (it is flattened).
        order : int or array-like of int, optional
            Order(s) of differentiation; every entry must be `< k`.
            Non-zero orders require `derivative_op`, which currently raises
            `NotImplementedError`.

        Returns
        -------
        list of scipy.sparse.csr_matrix
            One matrix per requested order that was produced, appended in
            the sequence the recursion reaches them (highest order first).

        Raises
        ------
        ValueError
            If any order is `>= k` or `x` is not a vector.
        """
        breaks, k = self.breaks, self.k
        order = np.atleast_1d(order)

        # error handling
        if any(order >= k):
            raise ValueError("Order of diff must be < k")

        if x.ndim > 1:
            if any(np.asarray(x.shape[1:]) > 1):
                raise ValueError("x must be a vector")
            # flatten
            x = x.reshape(-1)

        m = len(x)
        minorder = order.min()

        # Augment the breakpoint sequence with repeated end points
        n = self.n
        a = breaks[0]
        b = breaks[-1]
        augbreaks = np.concatenate(
            (np.full(k - minorder, a), breaks, np.full(k - minorder, b)))
        ind = lookup(augbreaks, x, 3)

        bas = np.zeros((k - minorder + 1, m))
        bas[0, :] = 1.0
        B = []

        if order.max() > 0:
            D = self.derivative_op(order.max())
        if minorder < 0:
            I = self.derivative_op(minorder)

        for j in range(1, k - minorder + 1):
            for jj in range(j, 0, -1):
                b0 = augbreaks[ind + jj - j]
                b1 = augbreaks[ind + jj]
                temp = bas[jj - 1, :] / (b1 - b0)
                bas[jj, :] = (x - b0) * temp + bas[jj, :]
                bas[jj - 1, :] = (b1 - x) * temp

            # bas now contains the `j` order spline basis
            hits = np.argwhere(order == k - j)
            if len(hits) == 0:
                continue
            ii = hits[0][0]

            # put the values into appropriate rows/columns of a sparse
            # matrix; every row stores exactly `per_row` values, so the row
            # pointer must advance by that amount
            per_row = k - order[ii] + 1
            row_ind_ptr = np.arange(
                0, m * per_row + 1, per_row, dtype="int32")
            colval = np.arange(-k + order[ii], 1,
                               dtype="int32") - (order[ii] - minorder)
            colval = (colval[None, :] + ind[:, None]).ravel(order="C")
            val = np.ravel(bas[:per_row, :], order="F")

            B_ii = spa.csr_matrix(
                (val, colval, row_ind_ptr), shape=(m, n - order[ii]))

            if order[ii] > 0:
                B.append(B_ii.dot(D[order[ii]]))
            elif order[ii] < 0:
                B.append(B_ii.dot(I[order[ii]]))
            else:
                B.append(B_ii)

        return B
class Lin(LinearBasis):
    """
    Piecewise-linear basis over a sorted sequence of breakpoints.

    Parameters
    ----------
    breaks : list, tuple or numpy.ndarray
        A list or tuple with exactly three entries is read as `(lb, ub, n)`
        and expanded with `np.linspace` (the same convention `Spline`
        uses); any other list or tuple, or an array, is used as the
        breakpoints themselves.
    evennum : int, optional
        Stored on the instance; `basis_matrix` only supports `0`.

    Raises
    ------
    ValueError
        If `breaks` has an unsupported type or is not sorted.
    """

    def __init__(self, breaks, evennum=0):
        if isinstance(breaks, (list, tuple)):
            if len(breaks) == 3:
                # could be of form (lb, ub, n)
                breaks = np.linspace(*breaks)
            else:
                breaks = np.asarray(breaks)

        if not isinstance(breaks, np.ndarray):
            raise ValueError("Couldn't interpret input form of breaks")

        # the index lookup in `basis_matrix` needs ascending breakpoints
        if not np.all(breaks == np.sort(breaks)):
            raise ValueError("breaks must be sorted")

        self.breaks = breaks
        self.evennum = evennum

    @property
    def nodes(self):
        """Return the breakpoints, which double as the nodes."""
        return self.breaks

    def basis_matrix(self, x, orders=0):
        """
        Evaluate the piecewise-linear basis at the points `x`.

        Each row of the result has two stored entries, `1 - z` and `z`, in
        the columns of the two breakpoints selected for that point, where
        `z` is the relative position of the point between them.

        Parameters
        ----------
        x : numpy.ndarray
            One-dimensional array of evaluation points.
        orders : int, optional
            Must be the scalar `0`.

        Returns
        -------
        list
            A single-element list holding a `scipy.sparse.csr_matrix` of
            shape `(len(x), number of breakpoints)` with floating-point
            values.

        Raises
        ------
        NotImplementedError
            If `orders` is not the integer `0`, or `self.evennum` is not 0.
        """
        m = x.shape[0]
        breaks = self.breaks
        n = breaks.size

        if not isinstance(orders, (int, np.integer)) or orders != 0:
            raise NotImplementedError("Only order=0 for Lin is implemented")

        if self.evennum != 0:
            raise NotImplementedError("I haven't done this bit yet")

        ind = lookup(breaks, x, 3)

        # relative position of each point inside its interval
        left = breaks[ind]
        z = (x - left) / (breaks[ind + 1] - left)

        row_ind_ptr = np.arange(0, 2 * m + 1, 2, dtype="int32")

        colval = np.empty((2 * m,), dtype="int32")
        colval[::2] = ind
        colval[1::2] = ind + 1

        # use the dtype of z so integer-valued x does not truncate weights
        val = np.empty((2 * m,), dtype=z.dtype)
        val[::2] = 1 - z
        val[1::2] = z

        out = spa.csr_matrix((val, colval, row_ind_ptr), shape=(m, n))
        return [out]
class Basis(object):
    """
    Multi-dimensional basis assembled from one `LinearBasis` per dimension.

    Parameters
    ----------
    *params : LinearBasis
        One basis per dimension, in dimension order.

    Raises
    ------
    ValueError
        If any argument is not a `LinearBasis` instance.
    """

    def __init__(self, *params):
        for p in params:
            if not isinstance(p, LinearBasis):
                msg = "Can only construct a basis using instances of "
                raise ValueError(msg + "LinearBasis")

        self.params = params
        self._nodes = None
        self.n = len(params)

    def _b_mat_tensor(self):
        """Placeholder; does nothing."""
        pass

    def _basis_mat_direct(self, x, orders):
        """
        Evaluate every dimension's basis on the matching column of `x`.

        Returns a list with one dict per dimension, mapping each distinct
        order requested for that dimension to its basis matrix.
        """
        out = []
        for col in range(orders.shape[1]):
            basis = self.params[col]
            out.append({
                o: basis(x[:, col], orders=o)
                for o in np.unique(orders[:, col])
            })
        return out

    def _basis_mat_tensor(self, x, orders):
        """
        Evaluate every dimension's basis on its own array `x[col]`.

        Returns a list with one dict per dimension, mapping each distinct
        order requested for that dimension to its basis matrix.
        """
        out = []
        for col in range(orders.shape[1]):
            basis = self.params[col]
            # key by the order itself so several orders do not overwrite
            # one another
            out.append({
                o: basis(x[col], orders=o)
                for o in np.unique(orders[:, col])
            })
        return out

    def basis_matrix(self, x, orders=0, format="direct"):
        """
        Build a `BasisMatrix` for the points `x`.

        Parameters
        ----------
        x : numpy.ndarray or sequence of arrays
            For `"direct"` / `"expanded"`: an array with one column per
            dimension (a 1d array of length `self.n` is treated as a single
            point). For `"tensor"`: a sequence of `self.n` arrays, one per
            dimension.
        orders : int or array-like of int, optional
            See `_check_order` for the accepted shapes.
        format : {"direct", "expanded", "tensor"}, optional
            Representation of the result.

        Raises
        ------
        ValueError
            If `x` or `orders` does not fit the basis, or `format` is
            unknown.
        """
        orders = self._check_order(orders)

        if format == "direct":
            if not isinstance(x, np.ndarray):
                msg = "x must be a numpy array for evaluation in direct form"
                raise ValueError(msg)

            if x.ndim == 1:
                if x.size != self.n:
                    msg = f"A 1d array was passed, must be shape (1, {self.n})"
                    raise ValueError(msg)
                x = x.reshape((1, self.n))

            if x.ndim == 2 and x.shape[1] != self.n:
                msg = f"A 2d array was passed, must have {self.n} columns"
                raise ValueError(msg)

            Bs = self._basis_mat_direct(x, orders)
            return BasisMatrix(Bs, orders, "direct")

        if format == "expanded":
            return self.basis_matrix(x, orders, format="direct").to_expanded()

        if format == "tensor":
            if len(x) != self.n:
                msg = f"For tensor format and {self.n} dim basis, must pass "
                raise ValueError(msg + f"{self.n} arrays")

            Bs = self._basis_mat_tensor(x, orders)
            return BasisMatrix(Bs, orders, "tensor")

        msg = "Format must be one of ['tensor', 'direct', 'expanded']."
        raise ValueError(msg + f"Found {format}")

    @property
    def nodes(self):
        """
        Return `(grid, dim_nodes)`, computed once and cached.

        `dim_nodes` is the list of per-dimension node arrays and `grid` is
        their cartesian product as built by `qe.cartesian`.
        """
        if self._nodes is None:
            dim_nodes = [x.nodes for x in self.params]
            grid = qe.cartesian(dim_nodes)
            self._nodes = (grid, dim_nodes)

        return self._nodes

    def _check_order(self, order):
        """
        Normalise `order` to a 2d integer array with `self.n` columns.

        A scalar is repeated for every dimension; a 1d input must have
        `self.n` entries and becomes a single row; a 2d input must already
        have `self.n` columns.

        Raises
        ------
        ValueError
            If the input does not fit one of these forms.
        """
        order = np.asarray(order)

        if order.ndim == 0:  # scalar case
            return np.full(shape=(1, self.n), fill_value=order)

        if order.ndim == 1:  # vector case
            if order.size != self.n:
                msg = f"Expected {self.n} orders, found {order.size}"
                raise ValueError(msg)
            return order.reshape((1, self.n))

        if order.ndim == 2:
            # one column per dimension is what the evaluators iterate over
            if order.shape[1] != self.n:
                msg = (f"2d order must have {self.n} columns, "
                       f"found {order.shape[1]}")
                raise ValueError(msg)
            return order

        msg = f"order of dimension {order.ndim}, can only handle ndim <= 2"
        raise ValueError(msg)
class BasisMatrix(object):
    """
    Container for basis matrices in one of three formats.

    Parameters
    ----------
    Bs : list
        For `"direct"` and `"tensor"`: one dict per dimension, mapping an
        order to that dimension's basis matrix. For `"expanded"`: one matrix
        per row of `orders`.
    orders : numpy.ndarray
        Two-dimensional array; each row holds one order per dimension.
    format : {"direct", "tensor", "expanded"}
        How `Bs` is to be interpreted.

    Raises
    ------
    ValueError
        If `orders` is not a 2d numpy array, or `Bs` is not a list/tuple of
        dicts for the `"direct"` / `"tensor"` formats.
    """

    def __init__(self, Bs, orders, format):
        if not isinstance(orders, np.ndarray):
            msg = "orders must be a numpy array"
            raise ValueError(msg)

        if orders.ndim != 2:
            msg = "orders must have 2 dimensions"
            raise ValueError(msg)

        if format in ["direct", "tensor"]:
            is_seq = isinstance(Bs, (tuple, list))
            if not is_seq or not all(isinstance(b, dict) for b in Bs):
                msg = "When using direct or tensor form, pass Bs as a list"
                raise ValueError(msg + " of dict")

        self.Bs = np.asarray(Bs)
        self.orders = orders
        self.format = format

    @staticmethod
    def _solve(B, rhs):
        """Solve `B @ c = rhs` for `c` (sparse solve or dense least squares)."""
        if spa.issparse(B):
            # TODO: figure out which solver to use
            return spla.spsolve(B, rhs)
        # lstsq returns (solution, residues, rank, singular values)
        return la.lstsq(B, rhs)[0]

    def to_direct(self):
        """
        Return this basis matrix in `"direct"` format.

        A `"tensor"` instance is converted by selecting, for every
        dimension, the rows of its matrices given by the cartesian product
        of the per-dimension row indices.

        Raises
        ------
        ValueError
            If the instance is not in `"direct"` or `"tensor"` format.
        """
        if self.format == "direct":
            return self

        if self.format != "tensor":
            msg = "Cannot convert from expanded to direct form"
            raise ValueError(msg)

        raw_inds = []
        for dim_bs in self.Bs:
            # all matrices of one dimension share their number of rows
            nrows = next(iter(dim_bs.values())).shape[0]
            raw_inds.append(np.arange(nrows, dtype="int32"))

        inds = qe.cartesian(raw_inds)

        out = []
        for i, dim_bs in enumerate(self.Bs):
            out.append(
                {key: val[inds[:, i], :] for (key, val) in dim_bs.items()})

        return BasisMatrix(out, self.orders, "direct")

    def to_expanded(self):
        """
        Return this basis matrix in `"expanded"` format.

        For every row of `orders` the per-dimension matrices of the
        requested orders are combined with `row_kron` (`"direct"`) or
        `full_kron` (`"tensor"`).

        Raises
        ------
        ValueError
            If the format is not one of the three known ones.
        """
        if self.format == "expanded":
            return self

        if self.format not in ("direct", "tensor"):
            msg = f"Unknown format {self.format}... How'd you get this?"
            raise ValueError(msg)

        func = row_kron if self.format == "direct" else full_kron
        out = []
        for row in range(self.orders.shape[0]):
            ords = self.orders[row, :]
            to_rk = [self.Bs[i][o] for (i, o) in enumerate(ords)]
            out.append(reduce(func, to_rk))

        return BasisMatrix(out, self.orders, "expanded")

    def filter(self, y):
        """
        Solve for the coefficients `c` such that the order-zero basis
        matrix applied to `c` reproduces `y`.

        Parameters
        ----------
        y : numpy.ndarray
            Values to fit; its first axis runs over the evaluation points.
            For the `"tensor"` format the points are assumed to be ordered
            like a Kronecker product of the dimensions (last dimension
            varying fastest).

        Returns
        -------
        numpy.ndarray
            The coefficients. Sparse matrices are solved with
            `spla.spsolve`, dense ones with `la.lstsq`.

        Raises
        ------
        ValueError
            If no row of `orders` is all zeros, or (tensor format) the
            number of rows of `y` does not match the basis matrices.
        """
        ii = np.where(np.all(self.orders == 0, axis=1))[0]
        if len(ii) == 0:
            msg = "Zero order basis not found in BasisMatrix. can't filter"
            raise ValueError(msg)

        if self.format == "direct":
            return self.to_expanded().filter(y)

        if self.format == "expanded":
            return self._solve(self.Bs[ii[0]], y)

        if self.format == "tensor":
            mats = [dim_bs[0] for dim_bs in self.Bs]
            y = np.asarray(y)
            rows = [mat.shape[0] for mat in mats]
            if np.prod(rows) != y.shape[0]:
                raise ValueError("BasisMatrix and y are unconformable")

            trailing = y.shape[1:]
            # one axis per dimension, plus one for any extra columns of y
            z = y.reshape(tuple(rows) + (-1,))
            for axis, mat in enumerate(mats):
                # bring the current dimension to the front so that the solve
                # acts along it, then put it back where it was
                z = np.moveaxis(z, axis, 0)
                rest = z.shape[1:]
                sol = np.asarray(self._solve(mat, z.reshape(z.shape[0], -1)))
                z = np.moveaxis(
                    sol.reshape((mat.shape[1],) + rest), 0, axis)

            return z.reshape((-1,) + trailing)
# ----- #
# Tests #
# ----- #
def test_row_kron():
    """
    Check `row_kron` against a fixed reference result and against
    `np.kron` applied row by row, for all dense / CSR input combinations.
    """
    def as_dense(mat):
        # sparse results are compared through their dense form
        return mat.toarray() if spa.issparse(mat) else mat

    left = np.array([
        [
            0.8264199353134758, 0.3348224988763786, 0.683690418471337,
            0.5021681483282827
        ],
        [
            0.8083856161115917, 0.8409840370611767, 0.819167773060582,
            0.6669167460999958
        ],
        [
            0.30063303241442885, 0.6608093259051855, 0.3891429787229923,
            0.018140909322307497
        ],
        [
            0.9497853372445253, 0.11998096796907887, 0.7869829064282052,
            0.3769530502240095
        ],
        [
            0.10066745535712318, 0.7476934606909804, 0.5850579436620573,
            0.35147541422531225
        ],
    ])
    right = np.array([
        [0.5583446948325421, 0.8480075982050823, 0.5531501558443193],
        [0.0102169807402257, 0.12707447035237185, 0.4130074757062898],
        [0.7226883352203841, 0.011663092121530827, 0.6397819036107639],
        [0.8642752828232878, 0.4379128600951878, 0.23778812232529578],
        [0.7871169627111041, 0.175590791212894, 0.4047727247721271],
    ])
    expected = np.array([
        [
            0.46142718658613185, 0.7008103844539801, 0.4571343160115014,
            0.1869463659582008, 0.2839320230971817, 0.18520711743365323,
            0.3817349180613117, 0.5797746696837062, 0.3781834615266879,
            0.2803829215329778, 0.42584240535896056, 0.27777438950784283
        ],
        [
            0.00825926027048762, 0.1027251740078563, 0.3338693027075223,
            0.0085923177094913, 0.10686760108434849, 0.3473326942559215,
            0.008369421360373544, 0.10409531089140539, 0.33832241413169384,
            0.006813875550237651, 0.08474809227978422, 0.2754416018130119
        ],
        [
            0.21726398570783936, 0.0035063107518246467, 0.19233957376637978,
            0.4775591916365227, 0.007707080042798866, 0.4227738484513653,
            0.28122909145602065, 0.00453861040929317, 0.24896663570415903,
            0.013110223557522352, 0.0002115790965944097, 0.011606225499456144
        ],
        [
            0.8208759909684239, 0.41592321350922257, 0.22584767195547348,
            0.10369658502488747, 0.052541208840328445, 0.02853004908813872,
            0.6801698740303301, 0.3446299353999989, 0.18713518762166687,
            0.3257912040934568, 0.16507258834520097, 0.08963495801756013
        ],
        [
            0.07923706170455447, 0.017676278135545946, 0.040747440200779216,
            0.5885222058180388, 0.1312880863474361, 0.30264591937818947,
            0.46050903162528284, 0.10273078723300941, 0.2368154980056686,
            0.27665226051265496, 0.06171584607570224, 0.1422676611063917
        ],
    ])

    left_csr = spa.csr_matrix(left)
    right_csr = spa.csr_matrix(right)
    fixed_cases = [
        (left, right),
        (left_csr, right),
        (left_csr, right_csr),
        (left, right_csr),
    ]
    for lhs, rhs in fixed_cases:
        assert np.allclose(as_dense(row_kron(lhs, rhs)), expected)

    for _ in range(5):
        rand_left_sp = spa.random(10, 5, 0.4)
        rand_right_sp = spa.random(10, 3, 0.4)
        rand_left = rand_left_sp.toarray()
        rand_right = rand_right_sp.toarray()

        # make sure dense-dense version is correct. Then we can compare
        # the other three versions against the dense-dense one
        have = row_kron(rand_left, rand_right)
        for row in range(rand_left_sp.shape[0]):
            by_row = np.kron(rand_left[row, :], rand_right[row, :])
            assert np.allclose(have[row, :], by_row)

        mixed_cases = [
            (rand_left, rand_right_sp),
            (rand_left_sp, rand_right),
            (rand_left_sp, rand_right_sp),
        ]
        for lhs, rhs in mixed_cases:
            assert np.allclose(row_kron(lhs, rhs).toarray(), have)
def test_full_kron():
    """
    Check `full_kron` against `np.kron` on random matrices for all dense /
    sparse input combinations.
    """
    for _ in range(5):
        left_sp = spa.random(10, 5, 0.5)
        right_sp = spa.random(10, 10, 0.5)
        left = left_sp.toarray()
        right = right_sp.toarray()
        expected = np.kron(left, right)

        assert np.allclose(full_kron(left, right), expected)
        sparse_cases = [
            (left_sp, right),
            (left, right_sp),
            (left_sp, right_sp),
        ]
        for lhs, rhs in sparse_cases:
            assert np.allclose(full_kron(lhs, rhs).toarray(), expected)
def prof_row_kron_csr_csr():
    """Profiling driver: run `row_kron` 5000 times on two random CSR inputs."""
    left = spa.random(500, 10, 0.2).tocsr()
    right = spa.random(500, 10, 0.2).tocsr()
    repeats = 5000
    for _ in range(repeats):
        row_kron(left, right)
def test_lookup():
    """Check `lookup` for every end-point mode on three kinds of table."""
    two_values = np.array([1.0, 4.0])
    repeated = np.array(
        [1.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0])
    single = np.array([1.0])
    x = np.array([0.5, 1.0, 1.5, 4.0, 5.5])
    x2 = np.array([0.5, 2.0])

    # (table, points, p, expected indices)
    cases = [
        (two_values, x, 0, [-1, 0, 0, 1, 1]),
        (two_values, x, 1, [0, 0, 0, 1, 1]),
        (two_values, x, 2, [-1, 0, 0, 0, 0]),
        (two_values, x, 3, [0, 0, 0, 0, 0]),
        (repeated, x, 0, [-1, 2, 2, 11, 11]),
        (repeated, x, 1, [2, 2, 2, 11, 11]),
        (repeated, x, 2, [-1, 2, 2, 7, 7]),
        (repeated, x, 3, [2, 2, 2, 7, 7]),
        (single, x2, 0, [-1, 0]),
        (single, x2, 1, [0, 0]),
        (single, x2, 2, [-1, 0]),
        (single, x2, 3, [0, 0]),
    ]
    for table, points, p, expected in cases:
        assert all(lookup(table, points, p) == np.array(expected))
if __name__ == '__main__':
    # Small demo: a 2d basis (order-1 spline x piecewise linear), evaluated
    # in tensor form at its own nodes and fitted to random data.
    spline_dim = Spline(np.linspace(0, 1, 10), 0, 1)
    linear_dim = Lin(np.linspace(0, 1, 10), 0)
    basis = Basis(spline_dim, linear_dim)

    grid, dim_nodes = basis.nodes
    tensor_bm = basis.basis_matrix(dim_nodes, format="tensor")
    values = np.random.randn(grid.shape[0])
    tensor_bm.filter(values)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment