Last active July 21, 2023 20:11
-
-
Save sweeneyde/8c7e36d82ba0598147519fa8fa1392de to your computer and use it in GitHub Desktop.
Smith Normal Form Calculator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Matrix:
    """A small dense integer matrix stored as a list of row lists.

    ``m`` is the number of rows and ``n`` the number of columns.  The
    column count is tracked explicitly so that 0-by-n matrices, which
    have no rows to measure, still carry their shape.
    """

    def __init__(self, arr, n=None):
        """Build a matrix from another ``Matrix`` or an iterable of rows.

        ``n`` is required when ``arr`` has no rows.  Otherwise it is
        optional and, if given, must equal the length of the rows.  The
        rows are copied, so later mutation of ``arr`` does not affect
        this matrix.
        """
        if isinstance(arr, Matrix):
            self.arr = [list(row) for row in arr.arr]
            self.m = arr.m
            self.n = arr.n
        else:
            # Copy first and measure the copy, so any iterable of rows
            # (including a one-shot generator) is accepted.
            self.arr = [list(row) for row in arr]
            self.m = len(self.arr)
            if not self.arr:
                # Make sure we can have 0-by-n matrices
                assert n is not None
                self.n = n
            else:
                self.n = len(self.arr[0])
                assert n is None or n == self.n
            assert all(len(row) == self.n for row in self.arr)

    def __repr__(self):
        if self.m == 0:
            # With no rows the column count cannot be recovered from
            # ``arr``, so spell it out to keep the repr round-trippable.
            return f"{type(self).__name__}({self.arr}, n={self.n})"
        return f"{type(self).__name__}({self.arr})"

    def __str__(self):
        """Render the entries in centered, equal-width columns, one line per row."""
        # default=0 keeps empty matrices (0-by-n or m-by-0) from raising.
        ln = max((len(str(x)) for row in self.arr for x in row), default=0)
        spec = " {:^" + str(ln) + "} "
        return "\n".join(
            "".join(spec.format(x) for x in row)
            for row in self.arr
        )

    def __matmul__(self, other):
        """Return the matrix product ``self @ other``."""
        if not isinstance(other, Matrix):
            return NotImplemented
        assert self.n == other.m
        kk = self.n
        A, B = self.arr, other.arr
        return type(self)(
            [
                [sum(A[i][k] * B[k][j] for k in range(kk))
                 for j in range(other.n)]
                for i in range(self.m)
            ],
            n=other.n,
        )

    def __eq__(self, other):
        """Matrices are equal when both their shapes and entries agree."""
        if not isinstance(other, Matrix):
            return NotImplemented
        # Compare the shape too: every 0-by-n matrix has ``arr == []``.
        return (self.m, self.n, self.arr) == (other.m, other.n, other.arr)

    @classmethod
    def id(cls, n):
        """Return the n-by-n identity matrix."""
        rn = range(n)
        return cls([[int(i == j) for j in rn] for i in rn], n=n)
class Snf:
    """Reduce an integer matrix D to Smith normal form in place.

    Only two kinds of elementary operation are applied to ``self.D``:
    adding an integer multiple of one row/column to another
    (``row_op``/``col_op``) and swapping two rows/columns
    (``swap_rows``/``swap_cols``).  Subclasses override these four
    methods to record the same operations in companion matrices.

    After ``smithify()`` the matrix D is diagonal and each diagonal
    entry divides the next one, so zero entries, if any, come last.
    Entries are only determined up to sign; negative values may remain.
    """

    ATTRIBUTES = ("D",)

    def __init__(self, *args, **kwargs):
        # The arguments are forwarded to Matrix, which copies the input,
        # so the caller's matrix is not modified.
        self.D = Matrix(*args, **kwargs)
        self.n = self.D.n
        self.m = self.D.m

    def smithify(self):
        """Bring ``self.D`` to Smith normal form in place."""
        self._make_diagonal()
        self._fix_divisibility()

    def get_matrices(self):
        """Return the matrices named in ``ATTRIBUTES``, in that order."""
        return [getattr(self, attr) for attr in self.ATTRIBUTES]

    def _make_diagonal(self):
        """Clear every off-diagonal entry using Euclidean steps.

        For each pivot k, column k below the pivot is cleared with row
        operations and row k right of the pivot with column operations.
        Column swaps can move nonzero entries back into column k, hence
        the repeat-until-clear loop.
        """
        D, m, n = self.D.arr, self.m, self.n
        for k in range(min(m, n)):
            while (any(D[i][k] for i in range(k + 1, m))
                   or any(D[k][j] for j in range(k + 1, n))):
                for i in range(k + 1, m):
                    self._improve_with_row_ops(k, i, k)
                for j in range(k + 1, n):
                    self._improve_with_col_ops(k, j, k)

    @staticmethod
    def _divides(a, b):
        """Return True if ``a`` divides ``b`` (0 divides only 0)."""
        return b % a == 0 if a else b == 0

    def _fix_divisibility(self):
        """Turn a diagonal D into one whose diagonal is a divisor chain.

        Every pair of diagonal positions i < j is visited; unless d_i
        already divides d_j, the pair is replaced by (gcd, lcm) up to
        sign.  Once all j > i are handled, d_i divides every later
        entry.  Later updates only take gcds/lcms of multiples of d_i,
        which stay multiples of d_i, so the whole diagonal ends up as a
        divisor chain.

        Merging only neighbouring entries is not enough: diag(2, 2, 3)
        would become diag(2, 1, -6).
        """
        D, m, n = self.D.arr, self.m, self.n
        assert all(D[i][j] == 0 or i == j
                   for i in range(m) for j in range(n))
        r = min(m, n)
        for i in range(r):
            for j in range(i + 1, r):
                if self._divides(D[i][i], D[j][j]):
                    continue
                self._merge_diagonal_pair(i, j)
        # Divisibility is transitive, so neighbouring checks suffice.
        assert all(self._divides(D[k][k], D[k + 1][k + 1])
                   for k in range(r - 1))

    def _merge_diagonal_pair(self, i, j):
        """Replace diagonal entries (A, B) at i < j by (gcd, lcm) up to sign.

        Requires D to be diagonal.  Only rows and columns i and j change.
        """
        # Start with block [ A 0 ]
        #                  [ 0 B ]
        self.col_op(i, j, 1)
        # Now have [ A 0 ]
        #          [ B B ]
        self._improve_with_row_ops(i, j, i)
        # Now have [ gcd(A,B) X ]
        #          [ 0        Y ]
        # Because we used row operations, X and Y are multiples of B.
        self._improve_with_col_ops(i, j, i)
        # Since gcd(A,B) divides B which divides X,
        # a single column operation suffices, and we have
        # [ gcd(A, B) 0 ]
        # [ 0         Y ]

    def row_op(self, target_i, source_i, multiplier):
        """Add ``multiplier`` times row ``source_i`` to row ``target_i``."""
        D = self.D.arr
        for j in range(self.n):
            D[target_i][j] += multiplier * D[source_i][j]

    def col_op(self, target_j, source_j, multiplier):
        """Add ``multiplier`` times column ``source_j`` to column ``target_j``."""
        D = self.D.arr
        for i in range(self.m):
            D[i][target_j] += multiplier * D[i][source_j]

    def swap_rows(self, i1, i2):
        """Exchange rows ``i1`` and ``i2``."""
        D = self.D.arr
        D[i1], D[i2] = D[i2], D[i1]

    def swap_cols(self, j1, j2):
        """Exchange columns ``j1`` and ``j2``."""
        for row in self.D.arr:
            row[j1], row[j2] = row[j2], row[j1]

    def _improve_with_row_ops(self, i1, i2, j):
        """Run Euclid's algorithm on D[i1][j] and D[i2][j] with row operations.

        Afterwards D[i2][j] == 0 and D[i1][j] is a gcd of the two
        original entries (up to sign).  If D[i1][j] is zero the rows are
        simply swapped.
        """
        # Carry the whole rows along for the ride.
        # It's important that if D[i2][j] % D[i1][j] == 0
        # then the effect is a single row operation to kill D[i2][j].
        D = self.D.arr
        if D[i1][j] == 0:
            self.swap_rows(i1, i2)
            return
        while True:
            # Floor division leaves a remainder smaller in magnitude than
            # the divisor, so this loop terminates.
            q = D[i2][j] // D[i1][j]
            self.row_op(i2, i1, -q)
            if D[i2][j] == 0:
                return
            self.swap_rows(i1, i2)

    def _improve_with_col_ops(self, j1, j2, i):
        """Column analogue of ``_improve_with_row_ops`` on D[i][j1], D[i][j2].

        Afterwards D[i][j2] == 0; a single column operation is used when
        D[i][j1] is nonzero and divides D[i][j2].
        """
        D = self.D.arr
        if D[i][j1] == 0:
            self.swap_cols(j1, j2)
            return
        while True:
            q = D[i][j2] // D[i][j1]
            self.col_op(j2, j1, -q)
            if D[i][j2] == 0:
                return
            self.swap_cols(j1, j2)
class SnfWithL(Snf):
    """Snf that also maintains an m-by-m matrix L, starting at the identity.

    Each row operation E applied to D is undone on the right of L
    (L <- L @ E^-1), so the product ``L @ D`` is unchanged by row
    operations.
    """

    ATTRIBUTES = ("L", "D")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        size = self.m
        self.L = Matrix(
            [[1 if r == c else 0 for c in range(size)] for r in range(size)],
            n=size,
        )

    def row_op(self, target_i, source_i, multiplier):
        super().row_op(target_i, source_i, multiplier)
        # D gained multiplier * (row source_i) in row target_i; compensate
        # by removing multiplier * (column target_i) from column source_i.
        for row in self.L.arr:
            row[source_i] -= multiplier * row[target_i]

    def swap_rows(self, i1, i2):
        super().swap_rows(i1, i2)
        # A swap is its own inverse: exchange the matching columns of L.
        for row in self.L.arr:
            row[i2], row[i1] = row[i1], row[i2]
class SnfWithR(Snf):
    """Snf that also maintains an n-by-n matrix R, starting at the identity.

    Each column operation F applied to D is undone on the left of R
    (R <- F^-1 @ R), so the product ``D @ R`` is unchanged by column
    operations.
    """

    ATTRIBUTES = ("D", "R")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        size = self.n
        self.R = Matrix(
            [[1 if r == c else 0 for c in range(size)] for r in range(size)],
            n=size,
        )

    def col_op(self, target_j, source_j, multiplier):
        super().col_op(target_j, source_j, multiplier)
        # Remove multiplier * (row target_j) from row source_j of R.
        target_row = self.R.arr[target_j]
        source_row = self.R.arr[source_j]
        for col, value in enumerate(target_row):
            source_row[col] -= multiplier * value

    def swap_cols(self, j1, j2):
        super().swap_cols(j1, j2)
        # A column swap of D is matched by swapping the same rows of R.
        rows = self.R.arr
        rows[j2], rows[j1] = rows[j1], rows[j2]
class SnfWithLinv(Snf):
    """Snf that also maintains Linv, the accumulated row operations.

    Linv starts as the m-by-m identity, and every row operation or row
    swap applied to D is applied to Linv as well.
    """

    ATTRIBUTES = ("Linv", "D")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        size = self.m
        self.Linv = Matrix(
            [[1 if r == c else 0 for c in range(size)] for r in range(size)],
            n=size,
        )

    def row_op(self, target_i, source_i, multiplier):
        super().row_op(target_i, source_i, multiplier)
        # Mirror the row operation on Linv.
        rows = self.Linv.arr
        src, dst = rows[source_i], rows[target_i]
        for col, value in enumerate(src):
            dst[col] += multiplier * value

    def swap_rows(self, i1, i2):
        super().swap_rows(i1, i2)
        rows = self.Linv.arr
        rows[i2], rows[i1] = rows[i1], rows[i2]
class SnfWithRinv(Snf):
    """Snf that also maintains Rinv, the accumulated column operations.

    Rinv starts as the n-by-n identity, and every column operation or
    column swap applied to D is applied to Rinv as well.
    """

    ATTRIBUTES = ("D", "Rinv")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        size = self.n
        self.Rinv = Matrix(
            [[1 if r == c else 0 for c in range(size)] for r in range(size)],
            n=size,
        )

    def col_op(self, target_j, source_j, multiplier):
        super().col_op(target_j, source_j, multiplier)
        # Mirror the column operation on Rinv.
        for row in self.Rinv.arr:
            row[target_j] += multiplier * row[source_j]

    def swap_cols(self, j1, j2):
        super().swap_cols(j1, j2)
        for row in self.Rinv.arr:
            row[j2], row[j1] = row[j1], row[j2]
class SnfWithAll(SnfWithL, SnfWithR, SnfWithLinv, SnfWithRinv):
    """Snf that maintains all four companion matrices at once.

    The cooperative ``super()`` calls in each parent class let every
    elementary operation update L, R, Linv and Rinv together.  The
    self-checks in this file verify that ``X == L @ D @ R`` and
    ``Linv @ X @ Rinv == D`` for the input matrix X.
    """

    ATTRIBUTES = (
        "Linv",
        "L",
        "D",
        "R",
        "Rinv",
    )
def test():
    """Self-check the Snf classes on exhaustive and random matrices.

    Checks every 0/1 matrix up to 3-by-4, then 10,000 random matrices
    with 0 to 9 rows and columns and entries in [-100, 100).  Raises
    AssertionError on the first failure.
    """
    from random import randrange
    from itertools import product

    def check_SnfWithAll(X):
        # Verify both factorizations, the inverse pairs, and that D is
        # diagonal.  Returns the computed matrices keyed by name.
        m, n = X.m, X.n
        S = SnfWithAll(X)
        S.smithify()
        Linv, L, D, R, Rinv = S.get_matrices()
        id_m, id_n = Matrix.id(m), Matrix.id(n)
        assert X == L @ D @ R
        assert Linv @ X @ Rinv == D
        assert L @ Linv == id_m
        assert Linv @ L == id_m
        assert R @ Rinv == id_n
        assert Rinv @ R == id_n
        # assert diagonal
        for i, row in enumerate(D.arr):
            for j, x in enumerate(row):
                if i != j:
                    assert x == 0
        return {"Linv": Linv, "L": L, "D": D, "R": R, "Rinv": Rinv}

    def check_individual_components(X, expected):
        # Every single-purpose variant applies the same elementary
        # operations to D as SnfWithAll, so each matrix it tracks must
        # equal the corresponding one computed by SnfWithAll.
        for cls in (Snf, SnfWithL, SnfWithR, SnfWithLinv, SnfWithRinv):
            S = cls(X)
            S.smithify()
            for name, M in zip(cls.ATTRIBUTES, S.get_matrices()):
                assert M == expected[name], (cls.__name__, name)

    # check all zero-one matrices up to 3-by-4
    for m in range(1, 3 + 1):
        for n in range(1, 4 + 1):
            for data in product((0, 1), repeat=m * n):
                X = Matrix([data[start:start + n]
                            for start in range(0, m * n, n)], n=n)
                expected = check_SnfWithAll(X)
                check_individual_components(X, expected)
    print("deterministic tests pass!!")
    for case in range(10_000):
        m = randrange(10)
        n = randrange(10)
        X = Matrix([[randrange(-100, 100) for _ in range(n)]
                    for _ in range(m)], n=n)
        expected = check_SnfWithAll(X)
        check_individual_components(X, expected)
        if case % 1000 == 0:
            print(f"{case // 1000}k")
    print("random tests pass!!")
def get_kernel_basis(X):
    """Return a basis of the integer kernel {v : X @ v == 0} of X.

    X is handed to SnfWithRinv, so it may be a Matrix or a non-empty
    list of rows.  Each basis vector is a plain list of length X.n; an
    empty result means the kernel is trivial.

    Method: factor X = L @ D @ R with L, R invertible over the integers
    and D diagonal.  Since L is invertible, X @ v == 0 exactly when
    D @ w == 0 with w = R @ v.  The kernel of D is spanned by the
    standard vectors e_j for each column j without a nonzero diagonal
    entry (including every j >= the number of rows).  Mapping back,
    v = Rinv @ e_j, i.e. column j of Rinv.

    Example: X = [(2, 0, 1), (0, 2, 3)] has D = [(1, 0, 0), (0, 2, 0)]
    up to signs.  Only column 2 is zero, and the result spans the same
    lattice as (1, 3, -2).
    """
    S = SnfWithRinv(X)
    S.smithify()
    diag, rinv = (M.arr for M in S.get_matrices())
    rows, cols = S.m, S.n
    # Check j >= rows first: D has no diagonal entry in those columns.
    free_columns = [
        j for j in range(cols)
        if j >= rows or diag[j][j] == 0
    ]
    return [[row[j] for row in rinv] for j in free_columns]
def balance_chemical_equation(compounds):
    """Return a basis of all integer ways to balance a reaction.

    ``compounds`` is a list of dicts mapping element symbol to atom
    count, one per compound, with products given as negated counts.
    For x(N_2) + y(H_2) --> z(NH_3) this is
    [{'N': 2}, {'H': 2}, {'N': -1, 'H': -3}], which gives one equation
    per element:
        2x + 0y - 1z = 0
        0x + 2y - 3z = 0
    i.e. X @ (x, y, z) == 0 with X = [(2, 0, -1), (0, 2, -3)].  The
    balancings are therefore the integer kernel of X.

    Each returned vector lists coefficients in the order of
    ``compounds``.  A vector with more negative than positive entries
    is negated before being returned.
    """
    # Dict keys serve as an insertion-ordered set of element symbols.
    element_order = {}
    for compound in compounds:
        element_order.update(compound)
    equations = [
        [compound.get(element, 0) for compound in compounds]
        for element in element_order
    ]
    kernel = get_kernel_basis(Matrix(equations, n=len(compounds)))
    normalized = []
    for vec in kernel:
        # Net count of positive minus negative entries.
        sign_balance = sum(x // abs(x) for x in vec if x)
        normalized.append([-x for x in vec] if sign_balance < 0 else vec)
    return normalized
def balance_chemical_equation_string(text):
    """Balance a reaction written as text and return readable results.

    ``text`` looks like ``"H_2 + N_2 --> NH_3"``.  Compounds are
    separated by ``+``, and reactants from products by a single ``-->``;
    spaces around them are optional.  Inside a compound, each element
    starts with a capital letter and may be followed by a count, written
    either as ``_2`` or as plain digits (``H2O``).  Parentheses are not
    supported.

    Returns one string per basis vector of the solution space, e.g.
    ``"3*(H_2) + 1*(N_2) --> 2*(NH_3)"``.  Compounds with a zero
    coefficient are omitted.  An empty list means there is no
    nontrivial balancing.

    Raises ValueError if the text cannot be parsed.
    """

    def split_side(side):
        # Compound names separated by "+", surrounding spaces ignored.
        names = [name.strip() for name in side.split("+")]
        if not all(names):
            raise ValueError(f"empty compound in {side.strip()!r}")
        return names

    def parse(word):
        # Map each element symbol in ``word`` to its total atom count.
        if not word[0].isupper():
            # Anything before the first element symbol (e.g. a
            # coefficient) would otherwise be silently dropped.
            raise ValueError(
                f"compound must start with an element symbol: {word!r}")
        capitals = [i for i, c in enumerate(word) if c.isupper()]
        capitals.append(len(word))
        compound = {}
        for a, b in zip(capitals, capitals[1:]):
            element_string = word[a:b]
            name, sub, count = element_string.partition("_")
            if sub:
                count = int(count)  # ValueError for e.g. "H_" or "H_x"
            else:
                # Also accept plain trailing digits, as in "H2O".
                name = element_string.rstrip("0123456789")
                digits = element_string[len(name):]
                count = int(digits) if digits else 1
            compound[name] = compound.get(name, 0) + count
        return compound

    # Collapse every run of whitespace to a single space and trim.
    text = " ".join(text.split())
    left_text, arrow, right_text = text.partition("-->")
    if not arrow or "-->" in right_text:
        raise ValueError(f"expected exactly one '-->' in {text!r}")
    left = split_side(left_text)
    right = split_side(right_text)
    left_compounds = [parse(word) for word in left]
    right_compounds = [parse(word) for word in right]
    # Products contribute negated counts so each element sums to zero.
    compounds = left_compounds + [{k: -v for k, v in c.items()} for c in right_compounds]
    basis = balance_chemical_equation(compounds)
    output = []
    for vec in basis:
        lhs = " + ".join(f"{coeff}*({compname})" for (coeff, compname) in zip(vec, left) if coeff)
        rhs = " + ".join(f"{coeff}*({compname})" for (coeff, compname) in zip(vec[len(left):], right) if coeff)
        output.append(f"{lhs} --> {rhs}")
    return output
def interactive_equation_balancer():
    """Prompt for chemical equations and print every way to balance them.

    Runs until end of input or Ctrl-C at the prompt.  Blank input is
    ignored, and an equation that cannot be parsed prints an error
    message instead of ending the session.
    """
    print("Enter a chemical equation without coefficients.")
    print('example: H_2 + N_2 --> NH_3')
    print()
    while True:
        try:
            eqn = input("equation: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of with a traceback.
            print()
            return
        if not eqn.strip():
            continue
        try:
            results = balance_chemical_equation_string(eqn)
        except ValueError as err:
            print(f"Could not parse equation: {err}")
            print()
            continue
        if not results:
            print("No nontrivial ways to balance.")
        elif len(results) == 1:
            print("One way to balance:")
        else:
            print(f"{len(results)} ways to balance:")
        for res in results:
            print(" " + res)
        print()
if __name__ == "__main__":
    # Run the self-checks first, then hand control to the user.
    for step in (test, interactive_equation_balancer):
        step()
    # Library usage without the interactive prompt:
    # basis = balance_chemical_equation([{'N': 2}, {'H': 2}, {'N': -1, 'H': -3}])
    # print(basis)
    # S = SnfWithAll([[2, 0, -1], [0, 2, -3]])
    # S.smithify()
    # Linv, L, D, R, Rinv = S.get_matrices()
    # print(f"Linv=\n{Linv}\n\nL=\n{L}\n\nD=\n{D}\n\nR=\n{R}\n\nRinv=\n{Rinv}\n\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.