Last active
March 1, 2023 09:58
-
-
Save 0x1F9F1/fcd3095a6fe56a323d41e13fdbc7bdb9 to your computer and use it in GitHub Desktop.
Python class for performing matrix operations over GF(2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://wiki.sagemath.org/quickref?action=AttachFile&do=get&target=quickref-linalg.pdf | |
# https://github.com/emin63/pyfinite | |
# https://github.com/malb/m4ri | |
from functools import reduce | |
import sys | |
def extract_rows(table):
    '''
    Extracts a set of linearly independent rows that span a lookup table.

    Args:
        table: An iterable of distinct integers forming a linear subspace over
            GF(2): it contains 0, and the XOR of any two entries is an entry.

    Returns:
        A `Matrix` whose rows are a basis of `table`. Rows are picked greedily
        in table order. If enumerating their combinations (in the order used
        by `expand_rows`) reproduces `table` exactly, the picked rows are
        returned unchanged; otherwise the basis is returned in reduced row
        echelon form.

    Raises:
        ValueError: If `table` has duplicates, lacks 0, or is not closed
            under XOR.
    '''
    table = list(table)
    # Values of the table that are not yet spanned by the picked rows
    remaining = set(table)
    if len(remaining) != len(table):
        raise ValueError('Table is not a linear subspace over GF(2): it contains duplicate values')
    if 0 not in remaining:
        raise ValueError('Table is not a linear subspace over GF(2): it does not contain 0')
    remaining.discard(0)
    # We could pick rows in any order (e.g. sorted or shuffled), however
    # taking them in table order preserves the original order if no s-box is
    # required. Spanned values are never un-spanned, so a single forward
    # cursor over the table always finds the first unspanned value.
    candidates = iter(table)
    seen = [0]
    rows = []
    while remaining:
        # The first value (in table order) not yet in the span becomes a new row
        row = next(v for v in candidates if v in remaining)
        rows.append(row)
        # Double the span with this row; new values are appended in the same
        # order that `expand_rows` enumerates row combinations
        for v in seen[:]:
            v ^= row
            if v not in remaining:
                raise ValueError(f'Table is not a linear subspace over GF(2): {v} is missing')
            remaining.remove(v)
            seen.append(v)
    rows = Matrix(rows)
    # If the span was not enumerated in the table's order, canonicalize the basis
    if seen != table:
        rows = rows.rref()
    return rows
def expand_rows(rows):
    '''
    Enumerates every GF(2) linear combination of the rows of a matrix.

    Args:
        rows: A `Matrix`.

    Returns:
        A list of length `2 ** rows.nrows()` whose entry at index `mask` is
        `mask * rows`, i.e. the XOR of the rows selected by the set bits of
        `mask` (bit k selects row k).
    '''
    total = 1 << rows.nrows()
    combinations = []
    for mask in range(total):
        combinations.append(mask * rows)
    return combinations
# def bit(v, index): | |
# return (v >> index) & 1 | |
# def gf2_sum(values): | |
# return reduce(lambda x, y: x ^ y, values, 0) | |
def to_bits(value, count):
    '''
    Splits the low `count` bits of `value` into individual bits.

    Returns:
        A generator yielding `count` values of 0 or 1, least-significant
        bit first.
    '''
    positions = range(count)
    return ((value >> pos) & 1 for pos in positions)
def from_bits(bits):
    '''
    Packs an iterable of bits into an integer.

    The first bit becomes the least-significant bit of the result. Each item
    is OR-ed in at its position.
    '''
    packed = 0
    shift = 0
    for b in bits:
        packed |= b << shift
        shift += 1
    return packed
def gf2_row_mul(rows, value):
    '''
    Multiplies a row vector by a matrix over GF(2).

    Args:
        rows: An iterable of row integers.
        value: The row vector, as an integer; bit i selects row i.

    Returns:
        The XOR of every row whose corresponding bit in `value` is set.
    '''
    acc = 0
    for row in rows:
        if value & 1:
            acc ^= row
        value >>= 1
    return acc
if sys.version_info >= (3, 10, 0):
    def parity(value):
        '''
        Returns the parity of `value`: 1 if an odd number of bits are set,
        otherwise 0. Negative values use the bits of `abs(value)`, as
        `int.bit_count` does.
        '''
        return value.bit_count() & 1
else:
    def parity(value):
        '''
        Returns the parity of `value`: 1 if an odd number of bits are set,
        otherwise 0. Negative values use the bits of `abs(value)`, matching
        `int.bit_count` on Python 3.10+.
        '''
        # bin() renders negatives as '-0b...', so this counts the bits of
        # abs(value). A shift-and-fold loop would never terminate for
        # negative values, because `value >> i` stays at -1.
        return bin(value).count('1') & 1
def gf2_col_mul(rows, value):
    '''
    Multiplies a matrix by a column vector over GF(2).

    Args:
        rows: An iterable of row integers.
        value: The column vector, as an integer.

    Returns:
        An integer whose bit i is the GF(2) dot product of row i with
        `value`, i.e. the parity of `row & value`.
    '''
    return from_bits(parity(row & value) for row in rows)
# A matrix over GF(2)
# Each row is represented using an `int`
# Designed to mostly mirror sage-math
class Matrix:
    '''
    A dense matrix over GF(2).

    Each row is stored as a non-negative `int`, where bit j of row i is the
    entry at (i, j). The API is designed to mostly mirror SageMath. Instances
    are mutable (see `__setitem__`, `add_row` and `swap_rows`).
    '''

    __slots__ = ['_rows', '_ncols']

    # Mutable objects that define `__eq__` must not be hashable
    __hash__ = None

    def __init__(self, rows, ncols=None):
        '''
        Constructs a GF(2) matrix.

        Args:
            rows: An iterable of non-negative integers, one per row. Bit j of
                each integer is the entry in column j.
            ncols: The number of columns (bits) in each row. Defaults to the
                minimum number of bits required to represent the largest row
                (0 when there are no rows or every row is zero).

        Raises:
            ValueError: If a row is negative, or `ncols` is too small to hold
                the largest row.
        '''
        self._rows = list(rows)
        if any(row < 0 for row in self._rows):
            raise ValueError('Matrix rows must be non-negative integers')
        # `default=0` allows matrices without rows, e.g. `zero_matrix(0, n)`
        nbits = max(self._rows, default=0).bit_length()
        if ncols is None:
            ncols = nbits
        elif ncols < nbits:
            raise ValueError(f'Asked for {ncols} columns, but need at least {nbits}')
        self._ncols = ncols

    def nrows(self):
        '''Returns the number of rows.'''
        return len(self._rows)

    def ncols(self):
        '''Returns the number of columns.'''
        return self._ncols

    def size(self):
        '''Returns a tuple containing the size (nrows, ncols).'''
        return self.nrows(), self.ncols()

    def rows(self):
        '''Returns a new list of the rows, as integers.'''
        return list(self._rows)

    def iter_rows(self):
        '''Returns an iterator over the rows, as integers.'''
        return iter(self._rows)

    def row(self, i):
        '''Returns the row at i, as an integer.'''
        return self._rows[i]

    def row_bits(self, i):
        '''Returns the row at i, as a generator of bits (column 0 first).'''
        return to_bits(self.row(i), self.ncols())

    def columns(self):
        '''Returns a list of the columns, as integers (bit i is row i).'''
        return list(self.iter_columns())

    def iter_columns(self):
        '''Returns an iterator over the columns, as integers.'''
        return (self.column(j) for j in range(self.ncols()))

    def column(self, j):
        '''Returns the column at j, as an integer (bit i is row i).'''
        return from_bits(self.column_bits(j))

    def column_bits(self, j):
        '''
        Returns the column at j, as a generator of bits (row 0 first).

        Raises:
            IndexError: If j is out of range (negative j counts from the end).
        '''
        j = self._column_index(j)
        return ((row >> j) & 1 for row in self.iter_rows())

    def _column_index(self, j):
        '''
        Normalizes a column index, mirroring Python's negative indexing.

        Raises:
            IndexError: If j is out of range.
        '''
        ncols = self.ncols()
        if j < 0:
            j += ncols
        if not 0 <= j < ncols:
            raise IndexError('Column index is out of range')
        return j

    def copy(self):
        '''Returns a copy of self.'''
        return Matrix(self.iter_rows(), self.ncols())

    def transpose(self):
        '''Returns a transposed copy of self.'''
        return Matrix(self.columns(), self.nrows())

    def is_square(self):
        '''Returns whether self is square.'''
        return self.nrows() == self.ncols()

    def is_identity(self):
        '''Returns whether self is an identity matrix.'''
        return self.is_square() and all(row == (1 << i) for i, row in enumerate(self.iter_rows()))

    def add_row(self, i, j):
        '''Adds (XORs) row j into row i, in place.'''
        self._rows[i] ^= self._rows[j]

    def swap_rows(self, i, j):
        '''Swaps the values of rows i and j, in place.'''
        self._rows[i], self._rows[j] = self._rows[j], self._rows[i]

    def echelon_form(self, reduced=True, full_rank=False):
        '''
        Computes the row echelon form of self using Gaussian elimination.

        Args:
            reduced: Whether to convert to reduced row echelon form, which also
                clears each pivot column above the pivot.
            full_rank: Whether to halt early if a pivot is missing. E is then
                only partially reduced.

        Returns:
            tuple (
                E: The row echelon form of self
                T: An m by m matrix: `E == T * self`
                p: Indices of the non-zero pivot columns of E: `rank = len(p)`
            )

        References:
            https://en.wikipedia.org/wiki/Row_echelon_form#Reduced_row_echelon_form
            https://en.wikipedia.org/wiki/Invertible_matrix#Gaussian_elimination
            https://uk.mathworks.com/help/matlab/ref/rref.html
        '''
        # Row reduction is performed on E
        E = self.copy()
        # The transform performed on self to produce E. Every row operation
        # on E is mirrored on T, so `E == T * self` holds throughout.
        T = identity_matrix(E.nrows())
        # The indices of the non-zero pivot columns
        p = []
        h = 0  # Pivot row
        k = 0  # Pivot column
        while h < E.nrows() and k < E.ncols():
            # Search for a row with the k-th pivot
            pivot = next((i for i in range(h, E.nrows()) if E[i, k]), None)
            if pivot is not None:
                p.append(k)
                E.swap_rows(h, pivot)
                T.swap_rows(h, pivot)
                # Clear column k from the other rows (only the rows below h
                # unless reducing). Over GF(2), subtracting a row is XOR.
                for i in range(0 if reduced else (h + 1), E.nrows()):
                    if i != h and E[i, k]:
                        E.add_row(i, h)
                        T.add_row(i, h)
                h += 1
            elif full_rank:
                break
            k += 1
        assert E == T * self
        return E, T, p

    def rank(self):
        '''
        Returns the rank of self.
        See `echelon_form`

        References:
            https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Rank_from_row_echelon_forms
        '''
        _, _, p = self.echelon_form(reduced=False)
        return len(p)

    def rref(self):
        '''
        Returns the reduced row echelon form of self.
        See `echelon_form`
        '''
        E, _, _ = self.echelon_form(reduced=True)
        return E

    def CR(self):
        '''
        Computes the column-row factorization of self.

        Returns:
            tuple (
                C: An m x r matrix: the independent columns of self (non-zero pivot columns of self)
                R: An r x n matrix: the independent rows of self (non-zero rows of `self.rref()`)
            )
            Where `r = self.rank()`, `self == C * R`

        References:
            https://en.wikipedia.org/wiki/Rank_factorization
            https://www.norbertwiener.umd.edu/FFT/2020/Faraway%20Slides/Faraway%20Strang.pdf
            https://math.mit.edu/~gs/everyone/lucrweb.pdf
            https://math.mit.edu/~gs/linearalgebra/lafe019
        '''
        # Compute the reduced echelon form and pivots of self
        E, _, p = self.echelon_form(reduced=True)
        C = []  # Pivot columns
        R = []  # Reduced rows
        for h, k in enumerate(p):
            C.append(self.column(k))
            R.append(E.row(h))
        # Columns were gathered as integers, so build C^T and transpose it
        C = Matrix(C, self.nrows()).transpose()
        R = Matrix(R, self.ncols())
        assert self == C * R
        return C, R

    def solve_left(self, other):
        '''
        Returns a solution for x, where `x * self = other`.
        Requires self to be invertible (raises ValueError otherwise).
        '''
        # TODO: Calculate left-inverse
        return other * self.inverse()

    def solve_right(self, other):
        '''
        Returns a solution for x, where `self * x = other`.

        Raises:
            ValueError: If no solution exists.
        '''
        x = self.right_inverse() * other
        if self * x != other:
            raise ValueError('Inconsistent solution')
        return x

    def inverse(self):
        '''
        Computes the inverse of self.
        Requires self to be a square matrix of full rank.

        Returns:
            The inverse of self

        Raises:
            ValueError: If self is not square or not of full rank.

        References:
            https://en.wikipedia.org/wiki/Invertible_matrix
            https://en.wikipedia.org/wiki/Gaussian_elimination#Finding_the_inverse_of_a_matrix
        '''
        if not self.is_square():
            raise ValueError(f'Matrix is not invertible ({self.nrows()} x {self.ncols()} is not square)')
        E, T, p = self.echelon_form(reduced=True, full_rank=True)
        rank = len(p)
        # Check for full rank
        if rank != self.ncols():
            raise ValueError(f'Matrix is not invertible (rank is not full: got {rank}, needed {self.ncols()})')
        # The reduced echelon form should be the identity matrix
        assert E.is_identity()
        # The inverse should work both ways
        assert (self * T).is_identity()
        assert (T * self).is_identity()
        return T

    def right_inverse(self):
        '''
        Computes a right-inverse of self.

        Returns:
            An n x m matrix X. If self has full row rank, `self * X` is the
            identity. Otherwise X is a generalized inverse: for any `b` in the
            column space of self, `self * (X * b) == b`. `solve_right` relies
            on this property and verifies consistency itself.

        References:
            https://en.wikipedia.org/wiki/Invertible_matrix
            https://en.wikipedia.org/wiki/Gaussian_elimination#Finding_the_inverse_of_a_matrix
        '''
        _, T, p = self.echelon_form(reduced=True)
        # `basis * T` places row h of T at row k (the h-th pivot column), which
        # undoes the row reduction on the pivot columns of the echelon form
        basis = zero_matrix(self.ncols(), self.nrows())
        for h, k in enumerate(p):
            basis[k, h] = 1
        return basis * T

    def __getitem__(self, index):
        '''
        Returns row [i] (an int), entry [i, j] (0 or 1), or a slice of rows
        (a Matrix). Negative indices count from the end.
        '''
        if isinstance(index, tuple):
            i, j = index
            j = self._column_index(j)
            return (self._rows[i] >> j) & 1
        elif isinstance(index, int):
            return self._rows[index]
        elif isinstance(index, slice):
            return Matrix(self._rows[index], self._ncols)
        else:
            raise TypeError('Matrix index must be an integer or integer pair')

    def __setitem__(self, index, value):
        '''
        Sets row [i] (an int that fits in ncols bits) or entry [i, j] (0 or 1).
        Negative indices count from the end.
        '''
        if isinstance(index, tuple):
            i, j = index
            j = self._column_index(j)
            if value not in (0, 1):
                raise ValueError('Matrix entry must be 0 or 1')
            self._rows[i] = (self._rows[i] & ~(1 << j)) | (value << j)
        elif isinstance(index, int):
            # Also rejects negative values, since `-x >> n` is never 0
            if value >> self.ncols():
                raise ValueError('Row value is too large')
            self._rows[index] = value
        else:
            raise TypeError('Matrix index must be an integer or integer pair')

    def __iter__(self):
        '''Returns an iterator over the rows of self (as integers).'''
        return self.iter_rows()

    def __len__(self):
        '''Returns the number of rows.'''
        return self.nrows()

    def __mul__(self, other):
        '''
        Returns self * other. `other` is either a Matrix or an integer column
        vector, whose result is also an integer column vector.
        '''
        if isinstance(other, Matrix):
            if self.ncols() != other.nrows():
                raise ValueError(f'Cannot multiply {self.ncols()} column matrix with {other.nrows()} row matrix')
            # Row i of the product is (row i of self) * other
            return Matrix((row * other for row in self.rows()), other.ncols())
        elif isinstance(other, int):
            if other >> self.ncols():
                raise ValueError('Row value is too large')
            return gf2_col_mul(self.iter_rows(), other)
        else:
            return NotImplemented

    def __rmul__(self, other):
        '''
        Returns other * self, where `other` is an integer row vector. The
        result is also an integer row vector.
        '''
        if isinstance(other, int):
            if other >> self.nrows():
                raise ValueError('Column value is too large')
            return gf2_row_mul(self.iter_rows(), other)
        else:
            return NotImplemented

    def __eq__(self, other):
        '''Returns whether both matrices have the same rows and column count.'''
        if not isinstance(other, Matrix):
            return NotImplemented
        return self._ncols == other._ncols and self._rows == other._rows

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return f'{self.nrows()} x {self.ncols()} matrix over GF(2)'

    def __str__(self):
        return '\n'.join('[{}]'.format(' '.join(
            map(str, self.row_bits(i))
        )) for i in range(self.nrows()))
def identity_matrix(dim):
    '''
    Returns a dim x dim identity matrix over GF(2).

    Row i has only bit i set.
    '''
    diagonal = []
    bit = 1
    for _ in range(dim):
        diagonal.append(bit)
        bit <<= 1
    return Matrix(diagonal, dim)
def zero_matrix(nrows, ncols):
    '''
    Returns an nrows x ncols matrix over GF(2) whose entries are all zero.
    '''
    empty_rows = [0 for _ in range(nrows)]
    return Matrix(empty_rows, ncols)
if __name__ == '__main__':
    # Smoke test: exercise the basic API on a small 4 x 4 example
    mat = identity_matrix(4)
    print(repr(mat))
    print(mat)

    # The identity is square, symmetric and its own square
    assert mat.is_square() and mat.is_identity()
    assert mat.transpose() == mat
    assert (mat * mat).is_identity()

    # Scramble it with a few elementary row operations
    for target, source in ((0, 1), (2, 0)):
        mat.add_row(target, source)
    mat.swap_rows(3, 2)

    # An int on the left combines rows; an int on the right of the transpose
    # combines its columns, so both select the same combination
    assert 3 * mat == mat.transpose() * 3

    # Elementary row operations preserve the rank...
    assert mat.rank() == mat.nrows()
    # ...and the reduced echelon form, which is therefore still the identity
    reduced = mat.rref()
    assert reduced.is_identity()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment