Skip to content

Instantly share code, notes, and snippets.

@cheind
Last active January 11, 2021 07:44
Show Gist options
  • Save cheind/35dfb3e67263dfdf7da80366c18db320 to your computer and use it in GitHub Desktop.
Save cheind/35dfb3e67263dfdf7da80366c18db320 to your computer and use it in GitHub Desktop.
Transactional (undoable) matrix operations (row/column deletion, row/column swaps) implemented with permutation matrices
import numpy as np
from itertools import count
def perm_matrix(perm_indices):
    '''Returns the permutation matrix corresponding to given permutation indices.

    Here `perm_indices` defines the permutation order in the following sense:
    value `j` at index `i` will move row/column `j` of the original matrix to
    row/column `i` in the permuted matrix P*M / M*P^T.

    Params
    ------
    perm_indices: N
        permutation order; must contain every index 0..N-1 exactly once
        (negative indices count from the end, as in numpy indexing)

    Returns
    -------
    pm: NxN int32 array
        permutation matrix with pm[i, perm_indices[i]] == 1

    Raises
    ------
    IndexError
        if an index is outside [-N, N)
    ValueError
        if an index repeats, i.e. `perm_indices` is not a permutation
    '''
    N = len(perm_indices)
    # Converting to a list keeps a tuple from being read as a multi-axis index
    # and lets an empty input act as an (empty) integer index.
    cols = list(perm_indices)
    pm = np.zeros((N, N), dtype=np.int32)
    # Row i becomes the basis vector e_{perm_indices[i]}.
    pm[np.arange(N), cols] = 1
    # Every row holds exactly one 1, so pm is a permutation matrix iff every
    # column does too. Undo operations invert with the transpose, which is
    # only valid for a genuine permutation matrix.
    if not (pm.sum(axis=0) == 1).all():
        raise ValueError('perm_indices is not a permutation: {}'.format(cols))
    return pm
def binary_perm_matrix(i, j, N):
    '''Builds the N x N permutation matrix that swaps row/column `i` with `j`.

    Params
    ------
    i, j: int
        the two positions to exchange (i == j yields the identity)
    N: int
        size of the matrix
    '''
    order = np.arange(N)
    # Tuple assignment stores order[i] = j first, then order[j] = i.
    order[i], order[j] = j, i
    return perm_matrix(order)
def basis_vec(i, n, dtype=None):
    '''Standard basis vector e_i of length n.

    Params
    ------
    i: int
        position of the single 1 (numpy indexing rules apply)
    n: int
        vector length
    dtype: numpy dtype, optional
        element type; None uses numpy's default for `np.zeros`
    '''
    unit = np.zeros(n, dtype=dtype)
    unit[i] = 1
    return unit
class MatrixState:
    '''Undoable view of a 2D matrix under row/column permutations and deletions.

    The wrapped matrix `m` is never modified. Row and column permutations are
    accumulated in the permutation matrices `rp` and `cp`. The represented
    matrix is `rp @ m @ cp` with its last `dr` rows and `dc` columns hidden.

    Params
    ------
    m: RxC array
        matrix to wrap; stored by reference, not copied
    '''
    def __init__(self, m):
        self.R, self.C = m.shape
        self.m = m
        self.dr = 0  # Number of deleted rows
        self.dc = 0  # Number of deleted cols
        self.rp = np.eye(self.R, dtype=m.dtype)  # Accumulated row permutation, applied as rp @ m
        self.cp = np.eye(self.C, dtype=m.dtype)  # Accumulated col permutation, applied as m @ cp
        self.history = []  # NOTE(review): not used by the code in this module

    @property
    def matrix(self):
        '''Returns matrix as represented by the current state.

        The result equals `(rp @ m @ cp)[:R-dr, :C-dc]`. It is computed by
        gathering the selected rows/columns of `m` rather than by matrix
        products, for three reasons:

        - a product would turn inf/nan entries into nan (0 * inf is nan);
        - it would flip -0.0 to 0.0;
        - it can upcast the dtype.

        The returned array is a copy with `m`'s dtype.
        '''
        rows, cols = self.indices
        return self.m[np.ix_(rows, cols)]

    @property
    def indices(self):
        '''Returns original row and column indices of the current matrix state.

        Returns
        -------
        rows: original row index of each visible row, in display order
        cols: original column index of each visible column, in display order
        '''
        # Row i of rp has its single 1 in the column of the original row shown
        # at position i; for cp the same holds column-wise, hence cp.T.
        return np.where(self.rp)[1][:self.R-self.dr], np.where(self.cp.T)[1][:self.C-self.dc]

    def transaction(self):
        '''Returns a new MatrixTransaction operating on this state.'''
        return MatrixTransaction(self)
class UndoableMatrixOp:
    '''Interface for an operation that can be applied to and reverted on a MatrixState.

    Subclasses implement `apply`, which mutates the given state in place, and
    `undo`, which reverts exactly that mutation. MatrixTransaction undoes
    recorded ops newest-first.
    '''

    def apply(self, state):
        '''Applies this operation to `state` in place.'''
        raise NotImplementedError

    def undo(self, state):
        '''Reverts a previous `apply` of this operation on `state`.'''
        raise NotImplementedError
class SwapRowsOp(UndoableMatrixOp):
    '''Exchanges two rows of the current matrix view.

    Params
    ------
    i, j: int
        row positions in the current view
    '''

    def __init__(self, i, j):
        self.ids = (i, j)

    def apply(self, state):
        first, second = self.ids
        # Left-multiplying the accumulated row permutation reorders rows.
        self.p = binary_perm_matrix(first, second, state.R)
        state.rp = self.p @ state.rp

    def undo(self, state):
        # The inverse of a permutation matrix is its transpose.
        state.rp = self.p.T @ state.rp
class SwapColsOp(UndoableMatrixOp):
    '''Exchanges two columns of the current matrix view.

    Params
    ------
    i, j: int
        column positions in the current view
    '''

    def __init__(self, i, j):
        self.ids = (i, j)

    def apply(self, state):
        first, second = self.ids
        # Right-multiplying the accumulated column permutation reorders columns.
        self.p = binary_perm_matrix(first, second, state.C)
        state.cp = state.cp @ self.p

    def undo(self, state):
        # The inverse of a permutation matrix is its transpose.
        state.cp = state.cp @ self.p.T
class DeleteOp(UndoableMatrixOp):
    '''Deletes rows or columns by permuting them behind the visible part of the matrix.

    Nothing is removed from the underlying data. The deleted rows/columns are
    moved to the end of the currently visible range and hidden by increasing
    the state's deleted count (`dr`/`dc`). `undo` reverses both steps.

    Params
    ------
    ids: int or iterable of ints
        positions in the current (visible) matrix to delete; duplicates are
        ignored
    rows: bool
        delete rows when True, columns otherwise

    Raises
    ------
    IndexError (from `apply`)
        if an id is not a position inside the visible range
    '''
    def __init__(self, ids, rows=True):
        # numpy integer scalars are not `int`; accept them as a single id too.
        if isinstance(ids, (int, np.integer)):
            ids = [ids]
        # Deduplicate, keeping first-seen order. The number of hidden
        # rows/columns is len(self.ids), so a repeated id would otherwise hide
        # an extra row/column that was never deleted.
        self.ids = list(dict.fromkeys(ids))
        self.rows = rows

    def apply(self, state):
        visible = state.R - state.dr if self.rows else state.C - state.dc
        # delete_perm_matrix only moves ids inside the visible range, but every
        # id is counted as deleted. Reject the others before touching `state`.
        invalid = [i for i in self.ids if i not in range(visible)]
        if invalid:
            kind = 'row' if self.rows else 'column'
            raise IndexError('{} ids {} out of range for {} visible {}s'.format(
                kind, invalid, visible, kind))
        self.p = DeleteOp.delete_perm_matrix(state, self.ids, rows=self.rows)
        if self.rows:
            state.rp = self.p @ state.rp
            state.dr += len(self.ids)
        else:
            state.cp = state.cp @ self.p
            state.dc += len(self.ids)

    def undo(self, state):
        if self.rows:
            state.dr -= len(self.ids)
            state.rp = self.p.T @ state.rp
        else:
            state.dc -= len(self.ids)
            state.cp = state.cp @ self.p.T

    @staticmethod
    def delete_perm_matrix(state, ids, rows=True):
        '''Returns the permutation matrix that moves deleted rows/columns to the end of the visible range.

        Kept rows/columns are packed to the front in their current order, and
        the deleted ones follow in reverse order. Already deleted entries
        (positions >= N - d) stay where they are.

        Returns
        -------
        P for rows (apply as P @ rp), or P^T for columns (apply as cp @ P^T)
        '''
        N = state.R if rows else state.C
        d = state.dr if rows else state.dc
        pids = np.arange(N).astype(dtype=np.int32)  # pids[w] = current index moved to position w
        upper = N - d  # ignore already deleted ones
        to_delete = set(ids)  # O(1) membership instead of scanning a list per index
        rcnt = count(upper - 1, -1)
        cnt = count(0, 1)
        # Walk the visible indices in order. A kept index is placed at the next
        # free slot from the front; a deleted one at the next free slot from
        # the back.
        for i in range(upper):
            w = next(rcnt) if i in to_delete else next(cnt)
            pids[w] = i
        p = perm_matrix(pids)
        return p if rows else p.T
class MatrixTransaction:
    '''Context manager that groups undoable operations on a MatrixState.

    Operations issued inside the `with` block are applied immediately. If the
    block exits without `commit()` having been called, all of them are undone
    newest-first, restoring the state found on entry. Exceptions raised inside
    the block are not suppressed.

    Methods that apply an operation return the transaction itself, so calls
    can be chained, e.g. `t.swap_rows(0, 1).delete_rows(2)`.
    '''
    def __init__(self, matrix_state):
        self.matrix_state = matrix_state
        self.committed = None  # None until entered, then False/True
        self.ops = None  # applied ops (for rollback); None until first entered

    def __enter__(self):
        self.committed = False
        self.ops = []
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.committed:
            self._undo_all()

    def commit(self):
        '''Keeps all operations of this transaction when the `with` block exits.'''
        self.committed = True

    def swap_rows(self, i, j):
        return self._apply(SwapRowsOp(i, j))

    def swap_cols(self, i, j):
        return self._apply(SwapColsOp(i, j))

    def delete_rows(self, ids):
        return self._apply(DeleteOp(ids, rows=True))

    def delete_cols(self, ids):
        return self._apply(DeleteOp(ids, rows=False))

    def _undo(self):
        '''Reverts the most recently applied operation.'''
        op = self.ops.pop()
        op.undo(self.matrix_state)
        return self

    def _undo_all(self):
        '''Reverts all recorded operations, newest first.'''
        while self.ops:
            self._undo()

    def _apply(self, op):
        '''Applies `op` to the state, records it for rollback and returns self.'''
        if self.ops is None:
            # Without this guard op.apply() would mutate the state and only
            # then fail on `self.ops.append`, leaving an unrecorded change.
            raise RuntimeError('transaction is not active; use it in a `with` statement')
        op.apply(self.matrix_state)
        self.ops.append(op)
        return self
import numpy as np
from numpy.testing import assert_allclose
import matrix_transactions as u
def test_matrix_ops():
    '''End-to-end checks of MatrixState/MatrixTransaction.

    Covers rollback without commit, persistence with commit, nested
    transactions, and mixed row/column deletion.
    '''
    # Rolled back: commit() is never called inside the block.
    m = np.arange(9).astype(np.float32).reshape(3,3)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0,2)
        t.swap_rows(0,1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, [[5,4,3],[2,1,0]])
    assert_allclose(ms.matrix, m)

    # Kept: commit() was called, so the changes persist after the block.
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0,2)
        t.swap_rows(0,1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, [[5,4,3],[2,1,0]])
        t.commit()
    assert_allclose(ms.matrix, [[5,4,3],[2,1,0]])

    # Nested: the inner rollback restores the outer transaction's state and
    # the outer rollback restores the original matrix.
    m = np.arange(10).astype(np.float32).reshape(5,2)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_rows([2,3,0])
        assert_allclose(ms.matrix, [[2,3],[8,9]])
        assert_allclose(ms.indices[0], [1,4])
        assert_allclose(ms.indices[1], [0,1])
        with ms.transaction() as tt:
            tt.delete_rows(0)
            assert_allclose(ms.matrix, [[8,9]])
            assert_allclose(ms.indices[0], [4])
            assert_allclose(ms.indices[1], [0,1])
        assert_allclose(ms.matrix, [[2,3],[8,9]])
        assert_allclose(ms.indices[0], [1,4])
        assert_allclose(ms.indices[1], [0,1])
    assert_allclose(ms.matrix, m)

    # Column and row deletion combined, rolled back afterwards.
    m = np.arange(20).astype(np.float32).reshape(4,5)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_cols([1,2,3])
        t.delete_rows([0,2])
        assert_allclose(ms.matrix, [[5,9],[15,19]])
        assert_allclose(ms.indices[0], [1,3])
        assert_allclose(ms.indices[1], [0,4])
    assert_allclose(ms.matrix, m)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment