Skip to content

Instantly share code, notes, and snippets.

@0x1F9F1
Last active March 1, 2023 09:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 0x1F9F1/fcd3095a6fe56a323d41e13fdbc7bdb9 to your computer and use it in GitHub Desktop.
Save 0x1F9F1/fcd3095a6fe56a323d41e13fdbc7bdb9 to your computer and use it in GitHub Desktop.
Python class for performing matrix operations over GF(2)
# https://wiki.sagemath.org/quickref?action=AttachFile&do=get&target=quickref-linalg.pdf
# https://github.com/emin63/pyfinite
# https://github.com/malb/m4ri
from functools import reduce
import sys
def extract_rows(table):
    '''
    Extracts a set of linearly independent rows from a lookup table

    Args:
        table: A list of integers that contains 0. Every XOR combination of
            the rows picked below must also be present in the table,
            otherwise `list.remove` raises ValueError.
    Returns:
        A Matrix whose rows are the picked values. If the table generated
        from the picked rows is not in the same order as `table`, the rows
        are converted to reduced row echelon form first.
    '''
    # Values that are not yet reachable from the rows picked so far
    remaining = table.copy()
    remaining.remove(0)

    # Taking candidates front-to-back preserves the original order of the
    # table when no s-box is required
    span = [0]
    basis = []

    while remaining:
        # The first unreached value is independent of the rows picked so far
        candidate = remaining[0]
        basis.append(candidate)

        # Combine the new row with every value reached before it was added
        for known in list(span):
            combined = known ^ candidate
            span.append(combined)
            remaining.remove(combined)

    result = Matrix(basis)

    # If the generated table is not what we started with, perform row reduction
    if span != table:
        return result.rref()

    return result
def expand_rows(rows):
    '''
    Expands a matrix into the list of all combinations of its rows

    Entry `i` of the result is `i * rows`, for every `i` in
    `range(1 << rows.nrows())`.
    '''
    table = []
    for selector in range(1 << rows.nrows()):
        table.append(selector * rows)
    return table
# def bit(v, index):
# return (v >> index) & 1
# def gf2_sum(values):
# return reduce(lambda x, y: x ^ y, values, 0)
def to_bits(value, count):
    '''
    Returns a generator over the lowest `count` bits of `value`,
    starting with the least significant bit
    '''
    positions = range(count)
    return (1 & (value >> shift) for shift in positions)
def from_bits(bits):
    '''
    Packs an iterable of bits into an integer.
    The first bit becomes the least significant bit of the result.
    '''
    def place(packed, indexed_bit):
        # OR the bit into the position given by its index in the iterable
        position, bit = indexed_bit
        return packed | (bit << position)

    return reduce(place, enumerate(bits), 0)
def gf2_row_mul(rows, value):
    '''
    Returns the XOR of every `rows[i]` for which bit `i` of `value` is set
    (a row vector multiplied by a matrix over GF(2))
    '''
    selected = (row for index, row in enumerate(rows) if (value >> index) & 1)
    return reduce(lambda total, row: total ^ row, selected, 0)
if sys.version_info[:2] >= (3, 10):
    def parity(value):
        '''
        Returns 1 if `value.bit_count()` is odd, otherwise 0
        '''
        return value.bit_count() % 2
else:
    def parity(value):
        '''
        Returns the lowest bit of `value` after XOR-folding it onto itself

        The value is XOR-ed with itself shifted right by 1, 2, 4, ... bits,
        stopping as soon as the shifted value is zero.
        '''
        shift = 1
        while value >> shift:
            value ^= value >> shift
            shift <<= 1
        return value & 1
def gf2_col_mul(rows, value):
    '''
    Returns an integer whose bit `i` is the parity of `rows[i] & value`
    (a matrix multiplied by a column vector over GF(2))
    '''
    return from_bits(parity(row & value) for row in rows)
# A matrix over GF(2)
# Each row is represented using an `int`
# Designed to mostly mirror sage-math
class Matrix:
    '''
    A matrix over GF(2).

    Each row is stored as an `int`; bit `j` of a row is the entry in column `j`.
    Matrices with no rows (or no columns) are supported.
    '''
    __slots__ = ['_rows', '_ncols']

    def __init__(self, rows, ncols=None):
        '''
        Constructs a GF(2) matrix
        Args:
            rows: A sequence of integers representing the rows of the matrix (may be empty)
            ncols: The number of columns (bits) in each row. Defaults to the minimum number of bits required to represent the largest row
        Raises:
            ValueError: If `ncols` is smaller than the number of bits needed by the largest row
        '''
        self._rows = list(rows)
        # `default=0` allows matrices with no rows (empty slices, 0 x n products, ...)
        nbits = max(self._rows, default=0).bit_length()
        if ncols is None:
            ncols = nbits
        elif ncols < nbits:
            raise ValueError(f'Asked for {ncols} columns, but need at least {nbits}')
        self._ncols = ncols

    # Returns the number of rows
    def nrows(self):
        return len(self._rows)

    # Returns the number of columns
    def ncols(self):
        return self._ncols

    # Returns a tuple containing the size (nrows, ncols)
    def size(self):
        return self.nrows(), self.ncols()

    # Returns a list of the rows (a copy, safe to modify)
    def rows(self):
        return list(self._rows)

    # Returns an iterator over the rows
    def iter_rows(self):
        return iter(self._rows)

    # Returns the row at i
    def row(self, i):
        return self._rows[i]

    # Returns the row at i, as a generator of bits
    def row_bits(self, i):
        return to_bits(self.row(i), self.ncols())

    # Returns a list of the columns
    def columns(self):
        return list(self.iter_columns())

    # Returns an iterator over the columns
    def iter_columns(self):
        return (self.column(j) for j in range(self.ncols()))

    # Returns the column at j (bit i of the result is the entry in row i)
    def column(self, j):
        return from_bits(self.column_bits(j))

    # Returns the column at j, as a generator of bits
    def column_bits(self, j):
        if j >= self.ncols():
            raise IndexError('Column index is out of range')
        return ((row >> j) & 1 for row in self.iter_rows())

    # Returns a copy of self
    def copy(self):
        return Matrix(self.iter_rows(), self.ncols())

    # Returns a transposed copy of self
    def transpose(self):
        return Matrix(self.columns(), self.nrows())

    # Returns whether self is square
    def is_square(self):
        return self.nrows() == self.ncols()

    # Returns whether self is identity
    def is_identity(self):
        return self.is_square() and all(row == (1 << i) for i, row in enumerate(self.iter_rows()))

    # Adds row j to row i (in place)
    def add_row(self, i, j):
        self._rows[i] ^= self._rows[j]

    # Swaps the values of rows i and j (in place)
    def swap_rows(self, i, j):
        self._rows[i], self._rows[j] = self._rows[j], self._rows[i]

    def echelon_form(self, reduced=True, full_rank=False):
        '''
        Computes the row echelon form of self.
        Args:
            reduced: Whether to convert to reduced row echelon form
            full_rank: Whether to halt early if a pivot is missing
        Returns:
            tuple (
                E: The row echelon form of self
                T: An m by m matrix: `E == T * self`
                p: Indices of the non-zero pivot columns of E: `rank = len(p)`
            )
        References:
            https://en.wikipedia.org/wiki/Row_echelon_form#Reduced_row_echelon_form
            https://en.wikipedia.org/wiki/Invertible_matrix#Gaussian_elimination
            https://uk.mathworks.com/help/matlab/ref/rref.html
        '''
        # Row reduction is performed on E
        E = self.copy()
        # The transform performed on self to produce E
        T = identity_matrix(E.nrows())
        # The indices of the non-zero pivot columns
        p = []
        h = 0  # Pivot Row
        k = 0  # Pivot Column
        # Convert to row echelon form
        while h < E.nrows() and k < E.ncols():
            # Search for a row with the k-th pivot
            pivot = next((i for i in range(h, E.nrows()) if E[i, k]), None)
            # Check if we found the pivot
            if pivot is not None:
                p.append(k)
                # Every row operation is mirrored on T so that `E == T * self` holds
                E.swap_rows(h, pivot)
                T.swap_rows(h, pivot)
                # Remove the pivot from the necessary rows
                # (all other rows when reducing, otherwise only the rows below)
                for i in range(0 if reduced else (h + 1), E.nrows()):
                    if i != h and E[i, k]:
                        E.add_row(i, h)
                        T.add_row(i, h)
                h += 1
            elif full_rank:
                break
            k += 1
        assert E == T * self
        return E, T, p

    def rank(self):
        '''
        Returns the rank of self.
        See `echelon_form`
        References:
            https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Rank_from_row_echelon_forms
        '''
        _, _, p = self.echelon_form(reduced=False)
        return len(p)

    def rref(self):
        '''
        Returns the reduced row echelon form of self.
        See `echelon_form`
        '''
        E, _, _ = self.echelon_form(reduced=True)
        return E

    def CR(self):
        '''
        Computes the column-row factorization of self
        Returns:
            tuple (
                C: An m x r matrix: the independent columns of self (non-zero pivot columns of self)
                R: An r x n matrix: the independent rows of self (non-zero rows of `self.rref()`)
            )
        Where `r = self.rank()`, `self == C * R`
        References:
            https://en.wikipedia.org/wiki/Rank_factorization
            https://www.norbertwiener.umd.edu/FFT/2020/Faraway%20Slides/Faraway%20Strang.pdf
            https://math.mit.edu/~gs/everyone/lucrweb.pdf
            https://math.mit.edu/~gs/linearalgebra/lafe019
        '''
        # Compute the reduced echelon form and pivots of self
        E, _, p = self.echelon_form(reduced=True)
        # The pivot columns of self, collected as rows and then transposed
        pivot_columns = [self.column(k) for k in p]
        # The first `rank` rows of the reduced echelon form
        reduced_rows = [E.row(h) for h in range(len(p))]
        C = Matrix(pivot_columns, self.nrows()).transpose()
        R = Matrix(reduced_rows, self.ncols())
        assert self == C * R
        return C, R

    def solve_left(self, other):
        '''
        Returns a solution for x, where `x * self = other`
        Requires self is invertible
        '''
        # TODO: Calculate left-inverse
        return other * self.inverse()

    def solve_right(self, other):
        '''
        Returns a solution for x, where `self * x = other`
        Raises:
            ValueError: If the computed x does not satisfy `self * x == other`
        '''
        x = self.right_inverse() * other
        if self * x != other:
            raise ValueError('Inconsistent solution')
        return x

    def inverse(self):
        '''
        Computes the inverse of self
        Requires self is a square matrix of full rank
        Returns:
            The inverse of self
        Raises:
            ValueError: If self is not square, or does not have full rank
        References:
            https://en.wikipedia.org/wiki/Invertible_matrix
            https://en.wikipedia.org/wiki/Gaussian_elimination#Finding_the_inverse_of_a_matrix
        '''
        if not self.is_square():
            raise ValueError(f'Matrix is not invertible ({self.nrows()} x {self.ncols()} is not square)')
        E, T, p = self.echelon_form(reduced=True, full_rank=True)
        rank = len(p)
        # Check for full rank
        if rank != self.ncols():
            raise ValueError(f'Matrix is not invertible (rank is not full: got {rank}, needed {self.ncols()})')
        # The reduced echelon form should be the identity matrix
        assert E.is_identity()
        # The inverse should work both ways
        assert (self * T).is_identity()
        assert (T * self).is_identity()
        return T

    def right_inverse(self):
        '''
        Computes the right-inverse of self
        Returns:
            An n x m matrix `B * T`, where `T` is the transform returned by
            `echelon_form` and `B` has a 1 at `[k, h]` for the h-th pivot column `k`
        References:
            https://en.wikipedia.org/wiki/Invertible_matrix
            https://en.wikipedia.org/wiki/Gaussian_elimination#Finding_the_inverse_of_a_matrix
        '''
        _, T, p = self.echelon_form(reduced=True)
        # Route the h-th row of T to the position of the h-th pivot column
        basis = zero_matrix(self.ncols(), self.nrows())
        for h, k in enumerate(p):
            basis[k, h] = 1
        return basis * T

    # Returns row [i], entry [i, j], or a sub-matrix of rows [a:b]
    def __getitem__(self, index):
        if isinstance(index, tuple):
            i, j = index
            if j >= self.ncols():
                raise IndexError('Column index is out of range')
            return (self._rows[i] >> j) & 1
        elif isinstance(index, int):
            return self._rows[index]
        elif isinstance(index, slice):
            return Matrix(self._rows[index], self._ncols)
        else:
            raise TypeError('Matrix index must be an integer or integer pair')

    # Sets row [i] or entry [i, j]
    def __setitem__(self, index, value):
        if isinstance(index, tuple):
            i, j = index
            if j >= self.ncols():
                raise IndexError('Column index is out of range')
            if value not in [0, 1]:
                raise ValueError('Matrix entry must be 0 or 1')
            # Clear bit j, then set it to the new value
            self._rows[i] = (self._rows[i] & ~(1 << j)) | (value << j)
        elif isinstance(index, int):
            if value >> self.ncols():
                raise ValueError('Row value is too large')
            self._rows[index] = value
        else:
            raise TypeError('Matrix index must be an integer or integer pair')

    # Returns an iterator over the rows of self (as integers)
    def __iter__(self):
        return self.iter_rows()

    # Returns the number of rows
    def __len__(self):
        return self.nrows()

    # Returns the result of self * other
    # (other is either a Matrix, or an int treated as a column vector)
    def __mul__(self, other):
        if isinstance(other, Matrix):
            if self.ncols() != other.nrows():
                raise ValueError(f'Cannot multiply {self.ncols()} column matrix with {other.nrows()} row matrix')
            # Each `row * other` dispatches to `other.__rmul__(row)`
            return Matrix((row * other for row in self.iter_rows()), other.ncols())
        elif isinstance(other, int):
            if other >> self.ncols():
                raise ValueError('Row value is too large')
            return gf2_col_mul(self.iter_rows(), other)
        else:
            return NotImplemented

    # Returns the result of other * self
    # (other is an int treated as a row vector)
    def __rmul__(self, other):
        if isinstance(other, int):
            if other >> self.nrows():
                raise ValueError('Column value is too large')
            return gf2_row_mul(self.iter_rows(), other)
        else:
            return NotImplemented

    # Two matrices are equal when they have the same rows and the same number of columns
    def __eq__(self, other):
        if not isinstance(other, Matrix):
            # Let Python fall back to its default comparison instead of
            # failing with AttributeError on a non-matrix operand
            return NotImplemented
        return (self._rows, self._ncols) == (other._rows, other._ncols)

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return f'{self.nrows()} x {self.ncols()} matrix over GF(2)'

    # Formats one `[b0 b1 ...]` line per row, column 0 first
    def __str__(self):
        return '\n'.join(
            '[{}]'.format(' '.join(map(str, self.row_bits(i))))
            for i in range(self.nrows())
        )
# Returns a dim x dim identity matrix
def identity_matrix(dim):
    '''
    Builds the `dim` x `dim` matrix whose row `i` has only bit `i` set
    '''
    unit_rows = (1 << position for position in range(dim))
    return Matrix(unit_rows, dim)
# Returns a nrows x ncols zero matrix
def zero_matrix(nrows, ncols):
    '''
    Builds the `nrows` x `ncols` matrix in which every row is 0
    '''
    blank_rows = [0 for _ in range(nrows)]
    return Matrix(blank_rows, ncols)
if __name__ == '__main__':
    # Smoke test: start from the 4 x 4 identity and check its basic properties
    mat = identity_matrix(4)
    print(repr(mat))
    print(mat)

    assert mat.is_square()
    assert mat.is_identity()
    assert mat == mat.transpose()

    squared = mat * mat
    assert squared.is_identity()

    # Perform some elementary row operations
    mat.add_row(0, 1)
    mat.add_row(2, 0)
    mat.swap_rows(3, 2)

    # Left-multiply operates on rows, right-multiply operates on columns
    by_rows = 3 * mat
    by_columns = mat.transpose() * 3
    assert by_rows == by_columns

    # Elementary row operations do not change the rank
    assert mat.rank() == mat.nrows()

    # Elementary row operations do not change the reduced echelon form,
    # so it will still be the identity
    mat = mat.rref()
    assert mat.is_identity()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment