Skip to content

Instantly share code, notes, and snippets.

@sweeneyde
Last active July 21, 2023 20:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sweeneyde/8c7e36d82ba0598147519fa8fa1392de to your computer and use it in GitHub Desktop.
Save sweeneyde/8c7e36d82ba0598147519fa8fa1392de to your computer and use it in GitHub Desktop.
Smith Normal Form Calculator
class Matrix:
    """A small dense matrix stored as a list of row lists.

    Attributes:
        arr: the rows; each row is a list with ``n`` entries.
        m: number of rows.
        n: number of columns.
    """

    def __init__(self, arr, n=None):
        """Build a matrix by copying *arr*.

        *arr* is either another ``Matrix`` (copied row by row; *n* is then
        ignored) or an iterable of rows.  *n* is the column count: it must be
        given when there are no rows, so that 0-by-n matrices have a width,
        and is otherwise only checked against the length of the first row.

        Shape violations trip an ``assert``.
        """
        if isinstance(arr, Matrix):
            self.arr = [list(row) for row in arr.arr]
            self.m = arr.m
            self.n = arr.n
            return
        self.arr = [list(row) for row in arr]
        self.m = len(self.arr)
        if not self.arr:
            # Make sure we can have 0-by-n matrices.
            assert n is not None
            self.n = n
        else:
            self.n = len(self.arr[0])
            assert n is None or n == self.n
        assert all(len(row) == self.n for row in self.arr)

    def __repr__(self):
        name = type(self).__name__
        if self.arr:
            return f"{name}({self.arr})"
        # Without rows the width cannot be recovered from the data, so show it.
        return f"{name}({self.arr}, n={self.n})"

    def __str__(self):
        """Return the rows one per line, every entry centred in a common width."""
        # default=0 keeps matrices with no entries (0-by-n or m-by-0)
        # printable instead of raising ValueError from max().
        width = max((len(str(x)) for row in self.arr for x in row), default=0)
        spec = " {:^" + str(width) + "} "
        return "\n".join(
            "".join(spec.format(x) for x in row)
            for row in self.arr
        )

    def __matmul__(self, other):
        """Return the matrix product ``self @ other``."""
        if not isinstance(other, Matrix):
            return NotImplemented
        assert self.n == other.m
        inner = self.n
        A, B = self.arr, other.arr
        return type(self)(
            [
                [sum([A[i][k] * B[k][j] for k in range(inner)])
                 for j in range(other.n)]
                for i in range(self.m)
            ],
            n=other.n,
        )

    def __eq__(self, other):
        """Matrices are equal when both their shapes and their entries agree."""
        if not isinstance(other, Matrix):
            return NotImplemented
        # Comparing the shape as well distinguishes e.g. 0-by-3 from 0-by-5,
        # whose row lists are both empty.
        return (self.m, self.n) == (other.m, other.n) and self.arr == other.arr

    # Matrices are mutable and define __eq__, so they are not hashable.
    __hash__ = None

    @classmethod
    def id(cls, n):
        """Return the n-by-n identity matrix."""
        rn = range(n)
        return cls([[int(i == j) for j in rn] for i in rn], n=n)
class Snf:
    """Reduce an integer matrix to Smith normal form, in place.

    The working copy lives in ``self.D``.  ``smithify()`` transforms it using
    only the four primitives ``row_op``, ``col_op``, ``swap_rows`` and
    ``swap_cols``; subclasses override those to record the transformation
    matrices alongside ``D``.

    After ``smithify()`` every off-diagonal entry of ``D`` is zero and each
    diagonal entry divides the next one (zeros therefore come last).  The
    signs of the diagonal entries are not normalised.
    """

    # Names of the attributes returned, in order, by get_matrices().
    ATTRIBUTES = ("D",)

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``Matrix`` and copy them into ``D``."""
        self.D = Matrix(*args, **kwargs)
        self.n = self.D.n
        self.m = self.D.m

    def smithify(self):
        """Bring ``D`` to Smith normal form (up to the signs of the entries)."""
        self._make_diagonal()
        self._fix_divisibility()

    def get_matrices(self):
        """Return the matrices named in ``ATTRIBUTES``, in that order."""
        return [getattr(self, attr) for attr in self.ATTRIBUTES]

    def _make_diagonal(self):
        """Clear everything off the diagonal, one pivot position at a time."""
        D, m, n = self.D.arr, self.m, self.n
        for k in range(min(m, n)):
            # Clearing row k can re-dirty column k and vice versa, so keep
            # alternating until both are clean.
            while (any(D[i][k] for i in range(k + 1, m))
                   or any(D[k][j] for j in range(k + 1, n))):
                for i in range(k + 1, m):
                    self._improve_with_row_ops(k, i, k)
                for j in range(k + 1, n):
                    self._improve_with_col_ops(k, j, k)

    def _fix_divisibility(self):
        """Make each diagonal entry divide all the later ones.

        Requires ``D`` to be diagonal already.  Position ``k`` is combined
        with *every* later position ``j``, not only with ``k + 1``: a single
        pass over adjacent pairs is not enough (it turns diag(2, 2, 3) into
        diag(2, 1, 6), and leaves the zero of diag(0, 0, 5) in the middle).
        After the inner loop, ``D[k][k]`` is a gcd of all entries from ``k``
        on, and later steps only ever produce multiples of it.
        """
        D, m, n = self.D.arr, self.m, self.n
        assert all(D[i][j] == 0 or i == j
                   for i in range(m) for j in range(n))
        rank_bound = min(m, n)
        for k in range(rank_bound - 1):
            for j in range(k + 1, rank_bound):
                # Restricted to rows/columns k and j we start with
                #     [ A 0 ]
                #     [ 0 B ]
                self.col_op(k, j, 1)
                # Now have
                #     [ A 0 ]
                #     [ B B ]
                self._improve_with_row_ops(k, j, k)
                # Now have
                #     [ gcd(A,B) X ]
                #     [    0     Y ]
                # Because we used row operations, X and Y are multiples of B.
                self._improve_with_col_ops(k, j, k)
                # Since gcd(A,B) divides B which divides X, a single column
                # operation suffices, and we have
                #     [ gcd(A,B) 0 ]
                #     [    0     Y ]

    def row_op(self, target_i, source_i, multiplier):
        """Add ``multiplier`` times row ``source_i`` to row ``target_i`` of D."""
        D = self.D.arr
        for j in range(self.n):
            D[target_i][j] += multiplier * D[source_i][j]

    def col_op(self, target_j, source_j, multiplier):
        """Add ``multiplier`` times column ``source_j`` to column ``target_j`` of D."""
        D = self.D.arr
        for i in range(self.m):
            D[i][target_j] += multiplier * D[i][source_j]

    def swap_rows(self, i1, i2):
        """Exchange rows ``i1`` and ``i2`` of D."""
        D = self.D.arr
        D[i1], D[i2] = D[i2], D[i1]

    def swap_cols(self, j1, j2):
        """Exchange columns ``j1`` and ``j2`` of D."""
        for row in self.D.arr:
            row[j1], row[j2] = row[j2], row[j1]

    def _improve_with_row_ops(self, i1, i2, j):
        """Zero ``D[i2][j]`` using row operations on rows ``i1`` and ``i2``.

        This runs the Euclidean algorithm on the two entries of column ``j``,
        carrying their rows along for the ride.  It's important that if
        ``D[i2][j] % D[i1][j] == 0`` the effect is a single row operation
        that kills ``D[i2][j]``.
        """
        D = self.D.arr
        if D[i1][j] == 0:
            # Nothing to divide by: just move the other entry into place.
            self.swap_rows(i1, i2)
            return
        while True:
            q = D[i2][j] // D[i1][j]
            self.row_op(i2, i1, -q)
            if D[i2][j] == 0:
                return
            # The remainder is smaller in absolute value than the divisor;
            # make it the new divisor.
            self.swap_rows(i1, i2)

    def _improve_with_col_ops(self, j1, j2, i):
        """Zero ``D[i][j2]`` using column operations on columns ``j1`` and ``j2``.

        Column counterpart of ``_improve_with_row_ops``.
        """
        D = self.D.arr
        if D[i][j1] == 0:
            self.swap_cols(j1, j2)
            return
        while True:
            q = D[i][j2] // D[i][j1]
            self.col_op(j2, j1, -q)
            if D[i][j2] == 0:
                return
            self.swap_cols(j1, j2)
class SnfWithL(Snf):
    """Snf variant that also tracks a left factor ``L``.

    ``L`` starts as the m-by-m identity.  Every row operation applied to
    ``D`` is compensated by the inverse operation on the columns of ``L``,
    so the product ``L @ D`` is unaffected by row operations.
    """

    ATTRIBUTES = ("L", "D")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.L = Matrix.id(self.m)

    def row_op(self, target_i, source_i, multiplier):
        super().row_op(target_i, source_i, multiplier)
        # Undo on L's columns what was just done to D's rows.
        for l_row in self.L.arr:
            l_row[source_i] -= l_row[target_i] * multiplier

    def swap_rows(self, i1, i2):
        super().swap_rows(i1, i2)
        # A row swap in D corresponds to a column swap in L.
        for l_row in self.L.arr:
            l_row[i1], l_row[i2] = l_row[i2], l_row[i1]
class SnfWithR(Snf):
    """Snf variant that also tracks a right factor ``R``.

    ``R`` starts as the n-by-n identity.  Every column operation applied to
    ``D`` is compensated by the inverse operation on the rows of ``R``, so
    the product ``D @ R`` is unaffected by column operations.
    """

    ATTRIBUTES = ("D", "R")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.R = Matrix.id(self.n)

    def col_op(self, target_j, source_j, multiplier):
        super().col_op(target_j, source_j, multiplier)
        # Undo on R's rows what was just done to D's columns.
        r_rows = self.R.arr
        source_row, target_row = r_rows[source_j], r_rows[target_j]
        for j in range(self.n):
            source_row[j] -= target_row[j] * multiplier

    def swap_cols(self, j1, j2):
        super().swap_cols(j1, j2)
        # A column swap in D corresponds to a row swap in R.
        r_rows = self.R.arr
        r_rows[j1], r_rows[j2] = r_rows[j2], r_rows[j1]
class SnfWithLinv(Snf):
    """Snf variant that also tracks ``Linv``.

    ``Linv`` starts as the m-by-m identity and receives exactly the same row
    operations and row swaps as ``D``.
    """

    ATTRIBUTES = ("Linv", "D")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.Linv = Matrix.id(self.m)

    def row_op(self, target_i, source_i, multiplier):
        super().row_op(target_i, source_i, multiplier)
        # Replay the same row operation on Linv.
        linv_rows = self.Linv.arr
        target_row, source_row = linv_rows[target_i], linv_rows[source_i]
        for j in range(self.m):
            target_row[j] += multiplier * source_row[j]

    def swap_rows(self, i1, i2):
        super().swap_rows(i1, i2)
        # Replay the same row swap on Linv.
        linv_rows = self.Linv.arr
        linv_rows[i1], linv_rows[i2] = linv_rows[i2], linv_rows[i1]
class SnfWithRinv(Snf):
    """Snf variant that also tracks ``Rinv``.

    ``Rinv`` starts as the n-by-n identity and receives exactly the same
    column operations and column swaps as ``D``.
    """

    ATTRIBUTES = ("D", "Rinv")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.Rinv = Matrix.id(self.n)

    def col_op(self, target_j, source_j, multiplier):
        super().col_op(target_j, source_j, multiplier)
        # Replay the same column operation on Rinv.
        for rinv_row in self.Rinv.arr:
            rinv_row[target_j] += multiplier * rinv_row[source_j]

    def swap_cols(self, j1, j2):
        super().swap_cols(j1, j2)
        # Replay the same column swap on Rinv.
        for rinv_row in self.Rinv.arr:
            rinv_row[j1], rinv_row[j2] = rinv_row[j2], rinv_row[j1]
class SnfWithAll(SnfWithL, SnfWithR, SnfWithLinv, SnfWithRinv):
    """Snf variant that tracks L, R, Linv and Rinv all at once.

    Each primitive operation is passed along the ``super()`` chain, so every
    mixin updates its own matrix before ``Snf`` finally updates ``D``.
    """

    ATTRIBUTES = ("Linv", "L", "D", "R", "Rinv")
def test():
    """Run the module's self-checks, printing progress as it goes.

    Every 0/1 matrix from 1-by-1 up to 3-by-4, followed by 10,000 random
    integer matrices with up to 9 rows and 9 columns, is reduced with each
    Snf variant.  For SnfWithAll the factorisation identities are asserted
    and D is checked to be diagonal.
    """
    from itertools import product
    from random import randrange

    def check_SnfWithAll(X):
        solver = SnfWithAll(X)
        solver.smithify()
        Linv, L, D, R, Rinv = solver.get_matrices()
        id_m = Matrix.id(X.m)
        id_n = Matrix.id(X.n)
        assert X == L @ D @ R
        assert Linv @ X @ Rinv == D
        assert L @ Linv == id_m
        assert Linv @ L == id_m
        assert R @ Rinv == id_n
        assert Rinv @ R == id_n
        # D must be diagonal.
        for i, row in enumerate(D.arr):
            for j, entry in enumerate(row):
                assert i == j or entry == 0

    def check_individual_components(X):
        # Each variant must at least run to completion on its own.
        for variant in (Snf, SnfWithL, SnfWithR, SnfWithLinv, SnfWithRinv):
            variant(X).smithify()

    # Check all zero-one matrices up to 3-by-4.
    for m, n in product(range(1, 3 + 1), range(1, 4 + 1)):
        row_starts = range(0, m * n, n)
        for bits in product((0, 1), repeat=m * n):
            X = Matrix([bits[s:s + n] for s in row_starts], n=n)
            check_SnfWithAll(X)
            check_individual_components(X)
    print("deterministic tests pass!!")

    for case in range(10_000):
        m = randrange(10)
        n = randrange(10)
        entries = [[randrange(-100, 100) for _ in range(n)] for _ in range(m)]
        X = Matrix(entries, n=n)
        check_SnfWithAll(X)
        check_individual_components(X)
        if case % 1000 == 0:
            print(f"{case // 1000}k")
    print("random tests pass!!")
def get_kernel_basis(X):
    """Return a list of integer vectors spanning the kernel of matrix *X*.

    Each returned vector is a list of ``n`` numbers, where ``n`` is the
    number of columns of *X*.
    """
    # Write X = L*D*R with L, R invertible and D diagonal.  Since L is
    # invertible, X*v == 0 exactly when D*(R*v) == 0.
    #
    # D*w == 0 is easy to solve: w may be anything in the coordinates whose
    # column of D is zero, and must vanish elsewhere, so the unit vectors
    # e_j for the zero columns j form a basis.
    #
    # R*v == e_j means v == Rinv*e_j, i.e. column j of Rinv.
    solver = SnfWithRinv(X)
    solver.smithify()
    diag = solver.D.arr
    rinv = solver.Rinv.arr
    row_count, col_count = solver.m, solver.n

    basis = []
    for j in range(col_count):
        # Columns beyond the last row have no diagonal entry at all.
        if j >= row_count or diag[j][j] == 0:
            basis.append([rinv[i][j] for i in range(col_count)])
    return basis
def balance_chemical_equation(compounds):
    """Return coefficient vectors that balance the given compounds.

    *compounds* is a list of dicts mapping element name to atom count, with
    the counts of right-hand-side compounds negated.  For
    ``x(N_2) + y(H_2) --> z(NH_3)`` it is
    ``[{'N': 2}, {'H': 2}, {'N': -1, 'H': -3}]``, which encodes

        2x + 0y - 1z = 0      (N)
        0x + 2y - 3z = 0      (H)

    The result is a basis of the kernel of that coefficient matrix; each
    vector is negated if it has more negative than positive entries.
    """
    # Distinct element names, in order of first appearance.
    elements = list(dict.fromkeys(
        element for compound in compounds for element in compound
    ))
    # One equation (matrix row) per element, one unknown per compound.
    rows = []
    for element in elements:
        rows.append([compound.get(element, 0) for compound in compounds])

    oriented = []
    for vec in get_kernel_basis(Matrix(rows, n=len(compounds))):
        positives = sum(1 for x in vec if x > 0)
        negatives = sum(1 for x in vec if x < 0)
        if negatives > positives:
            vec = [-x for x in vec]
        oriented.append(vec)
    return oriented
def balance_chemical_equation_string(text):
    """Balance an equation written like ``"H_2 + N_2 --> NH_3"``.

    Compounds are separated by ``" + "`` and the two sides by ``" --> "``.
    Inside a compound every element starts at an uppercase letter and may be
    followed by ``_`` and an integer count (``"H_2"``); a missing count
    means 1.

    Returns a list of strings, one per basis vector found, of the form
    ``"1*(N_2) + 3*(H_2) --> 2*(NH_3)"``; compounds with coefficient 0 are
    left out.

    Raises:
        ValueError: if the text does not contain exactly one ``" --> "`` or
            a count after ``_`` is not an integer.
    """
    # Trim the ends and collapse each run of whitespace to a single space so
    # the exact-match separators below work.  (Replacing " " by " " in a
    # loop, as this function used to, never terminates.)
    text = " ".join(text.split())
    if text.count(" --> ") != 1:
        raise ValueError(
            "expected exactly one ' --> ' between the two sides of the equation"
        )
    left, right = text.split(" --> ")
    left = left.split(" + ")
    right = right.split(" + ")

    def parse(word):
        """Turn one compound such as ``"NH_3"`` into ``{'N': 1, 'H': 3}``."""
        capitals = [i for i, c in enumerate(word) if c.isupper()]
        capitals.append(len(word))
        # Each element runs from its capital letter up to the next capital.
        element_strings = [word[a:b] for a, b in zip(capitals, capitals[1:])]
        compound = {}
        for element_string in element_strings:
            name, sub, count = element_string.partition("_")
            count = int(count) if sub else 1
            compound[name] = compound.get(name, 0) + count
        return compound

    left_compounds = [parse(word) for word in left]
    right_compounds = [parse(word) for word in right]
    # Products count negatively so that a balanced equation sums to zero.
    compounds = left_compounds + [
        {k: -v for k, v in c.items()} for c in right_compounds
    ]
    basis = balance_chemical_equation(compounds)

    output = []
    for vec in basis:
        lhs = " + ".join(
            f"{coeff}*({compname})"
            for (coeff, compname) in zip(vec, left) if coeff
        )
        rhs = " + ".join(
            f"{coeff}*({compname})"
            for (coeff, compname) in zip(vec[len(left):], right) if coeff
        )
        output.append(f"{lhs} --> {rhs}")
    return output
def interactive_equation_balancer():
    """Prompt for chemical equations on stdin and print balanced versions.

    Blank lines are ignored.  A line that cannot be parsed is reported and
    the prompt continues.  The function returns when the input ends (EOF).
    """
    print("Enter a chemical equation without coefficients.")
    print('example: H_2 + N_2 --> NH_3')
    print()
    while True:
        try:
            eqn = input("equation: ")
        except EOFError:
            # End of input (Ctrl-D / closed pipe): leave quietly.
            print()
            return
        if not eqn.strip():
            continue
        try:
            results = balance_chemical_equation_string(eqn)
        except ValueError as err:
            # Malformed input should not end the whole session.
            print(f"Could not parse equation: {err}")
            print()
            continue
        if not results:
            print("No nontrivial ways to balance.")
        elif len(results) == 1:
            print("One way to balance:")
        else:
            print(f"{len(results)} ways to balance:")
        for res in results:
            print(" " + res)
        print()
if __name__ == "__main__":
    # Run the self-checks first, then start the interactive prompt.
    test()
    interactive_equation_balancer()

    # Other things to try by hand:
    #
    # basis = balance_chemical_equation([{'N': 2}, {'H': 2}, {'N': -1, 'H': -3}])
    # print(basis)
    #
    # S = SnfWithAll([[2, 0, -1], [0, 2, -3]])
    # S.smithify()
    # Linv, L, D, R, Rinv = S.get_matrices()
    # print(f"Linv=\n{Linv}\n\nL=\n{L}\n\nD=\n{D}\n\nR=\n{R}\n\nRinv=\n{Rinv}\n\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment