Skip to content

Instantly share code, notes, and snippets.

@mafrasi2
Last active August 29, 2015 14:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mafrasi2/14e7ff1d195ebc832d40 to your computer and use it in GitHub Desktop.
Save mafrasi2/14e7ff1d195ebc832d40 to your computer and use it in GitHub Desktop.
from math import sqrt
from itertools import count, islice, zip_longest
import sys
from pprint import pprint
def isPrime(n):
    """Return True if *n* is a prime number, False otherwise.

    Uses trial division with exact integer arithmetic only. A float
    square root is deliberately avoided: it loses precision for integers
    above 2**53 and raises OverflowError for integers too large to be
    converted to a float.
    """
    if n < 2:
        return False
    divisor = 2
    # Any composite n has a divisor d with d * d <= n, so stop there.
    while divisor * divisor <= n:
        if n % divisor == 0:
            return False
        divisor += 1
    return True
class MalformedTransversalError(Exception):
    """Raised when a collection of representatives is not a valid transversal.

    Used for a residue class that is represented twice or not at all.
    """
class Transversal:
    """A system of representatives for the residue classes modulo a prime.

    Attributes:
        prime: the modulus.
        t_dict: maps each residue ``r`` (``0 <= r < prime``) to the
            representative chosen for its class.
    """

    def __init__(self, t, prime):
        """Build the transversal from the representatives in *t*.

        Args:
            t: iterable of representatives, one per residue class.
            prime: the modulus; must be prime.

        Raises:
            ValueError: if *prime* is not a prime number.
            MalformedTransversalError: if two representatives fall into
                the same residue class, or a residue class has none.
        """
        self.prime = prime
        self.t_dict = dict()
        if not isPrime(prime):
            raise ValueError("Expected prime, got {}".format(prime))
        # Residues still lacking a representative; a set keeps each
        # removal O(1) instead of scanning a list.
        missing = set(range(prime))
        for representative in t:
            residue = representative % prime
            if residue in self.t_dict:
                raise MalformedTransversalError("Doubled representative.")
            self.t_dict[residue] = representative
            missing.discard(residue)
        if missing:
            raise MalformedTransversalError("Representative missing.")

    def get_repr(self, n):
        """Return the representative of the residue class containing *n*."""
        return self.t_dict[n % self.prime]
class StdTransversal(Transversal):
    """Transversal whose representatives are the residues ``n % prime`` themselves.

    The parent constructor is intentionally not called: no lookup table
    is built, and the representative is computed directly on demand.
    """

    def __init__(self, prime):
        """Remember the modulus *prime*."""
        self.prime = prime

    def get_repr(self, n):
        """Return ``n % prime`` as the representative of *n*."""
        residue = n % self.prime
        return residue
class MalformedMatrixError(Exception):
    """Raised when the rows of a matrix do not all have the same length."""
class ResFieldMatrix:
    """A matrix whose arithmetic is carried out modulo ``prime``.

    The entries passed to the constructor are stored as given; the
    results of ``+``, ``-`` and ``*`` are new matrices whose entries are
    reduced modulo ``prime``. Operands are never modified.

    Attributes:
        m: the rows of the matrix (a sequence of equally long rows).
        prime: the modulus.
        width: number of columns (0 for a matrix without rows).
        height: number of rows.
    """

    def __init__(self, m, prime):
        """Store the rows *m* and the modulus *prime*.

        Raises:
            MalformedMatrixError: if the rows differ in length.
        """
        self.m = m
        self.prime = prime
        self.width, self.height = ResFieldMatrix.__check_matrix(m)

    def transpose(self):
        """Return the transposed matrix as a new ResFieldMatrix."""
        # zip(*rows) yields the columns; convert them to lists so the
        # rows of the result are of the same kind as ordinary rows.
        transposed = [list(column) for column in zip(*self.m)]
        return ResFieldMatrix(transposed, self.prime)

    def to_transversal(self, transversal=None):
        """Return the entries mapped to their representatives in *transversal*.

        Args:
            transversal: object providing ``prime`` and ``get_repr``;
                when omitted, a StdTransversal for this matrix's prime
                is used.

        Returns:
            A plain list of lists (not a ResFieldMatrix).

        Raises:
            ValueError: if *transversal* belongs to a different prime.
        """
        if transversal is None:
            transversal = StdTransversal(self.prime)
        elif transversal.prime != self.prime:
            raise ValueError("Expected transversal for prime {}, got {}".format(self.prime, transversal.prime))
        return [[transversal.get_repr(ele) for ele in row] for row in self.m]

    @staticmethod
    def __check_matrix(m):
        """Return ``(width, height)`` of *m*, verifying all rows are equally long."""
        width = 0
        for index, row in enumerate(m):
            if index == 0:
                width = len(row)
            elif len(row) != width:
                raise MalformedMatrixError(
                    "Row {} has {} entries, expected {}.".format(index, len(row), width))
        return width, len(m)

    def __mul__(self, other):
        """Return ``self * other`` for an int scalar or a ResFieldMatrix.

        Raises:
            ArithmeticError: if the matrices belong to different primes
                or their dimensions do not fit.
        """
        if isinstance(other, int):
            # Build a new matrix instead of mutating self, and return it.
            scaled = [[(other * ele) % self.prime for ele in row] for row in self.m]
            return ResFieldMatrix(scaled, self.prime)
        if isinstance(other, ResFieldMatrix):
            if other.prime != self.prime:
                raise ArithmeticError("Tried to multiplicate elements of different residue fields.")
            if self.width != other.height:
                raise ArithmeticError("Tried to multiplicate matrices with unfitting dimensions.")
            # The columns of other, each paired with every row of self.
            columns = list(zip(*other.m))
            product = [[sum(ele_a * ele_b for ele_a, ele_b in zip(row_a, col_b)) % self.prime
                        for col_b in columns] for row_a in self.m]
            return ResFieldMatrix(product, self.prime)
        return NotImplemented

    def __rmul__(self, other):
        """Return ``other * self`` for an int scalar."""
        if isinstance(other, int):
            return self.__mul__(other)
        return NotImplemented

    def __add__(self, other):
        """Return the entry-wise sum ``self + other`` modulo the prime.

        Raises:
            ArithmeticError: if the matrices belong to different primes
                or differ in dimensions.
        """
        if isinstance(other, ResFieldMatrix):
            if other.prime != self.prime:
                raise ArithmeticError("Tried to add elements of different residue fields.")
            if self.width != other.width or self.height != other.height:
                raise ArithmeticError("Tried to add matrices with unfitting dimensions.")
            summed = [[(ele_a + ele_b) % self.prime for ele_a, ele_b in zip(row_a, row_b)]
                      for row_a, row_b in zip(self.m, other.m)]
            return ResFieldMatrix(summed, self.prime)
        return NotImplemented

    def __sub__(self, other):
        """Return the entry-wise difference ``self - other`` modulo the prime.

        Raises:
            ArithmeticError: if the matrices belong to different primes
                or differ in dimensions.
        """
        if isinstance(other, ResFieldMatrix):
            if other.prime != self.prime:
                raise ArithmeticError("Tried to subtract elements of different residue fields.")
            if self.width != other.width or self.height != other.height:
                raise ArithmeticError("Tried to subtract matrices with unfitting dimensions.")
            # Order matters here: rows of self are the minuend.
            difference = [[(ele_a - ele_b) % self.prime for ele_a, ele_b in zip(row_a, row_b)]
                          for row_a, row_b in zip(self.m, other.m)]
            return ResFieldMatrix(difference, self.prime)
        return NotImplemented
def input_matrix():
    """Read a matrix from standard input and return it as Python list source text.

    Each non-blank input line is one row; entries are separated by any
    run of whitespace (spaces or tabs), and both LF and CRLF line
    endings are accepted. Entries are kept as the strings that were
    typed and right-aligned per column.

    Returns:
        A string such as ``"[[1,  2],\\n [3, 44]]"``; ``"[]"`` when no
        entries were read.
    """
    print("Type matrix. Press CTRL-D to end input:")
    text = sys.stdin.read()

    # One token list per non-blank line.
    rows = [tokens for tokens in (line.split() for line in text.splitlines()) if tokens]

    # Widest entry seen in each column; rows may differ in length.
    col_widths = []
    for row in rows:
        for j, token in enumerate(row):
            if j < len(col_widths):
                col_widths[j] = max(col_widths[j], len(token))
            else:
                col_widths.append(len(token))

    formatted = ["[" + ", ".join(token.rjust(width) for token, width in zip(row, col_widths)) + "]"
                 for row in rows]
    return "[" + ",\n ".join(formatted) + "]"
# Demo: multiply A by the transpose of C over the integers modulo 3 and
# print the product using the representatives -1, 0 and 1.
std_3_transversal = Transversal(range(-1, 2), 3)

A = ResFieldMatrix(
    [
        [1, -9, 3, 5],
        [7, -3, -6, 8],
        [5, -2, -4, 6],
        [-5, 1, 2, -8],
    ],
    3,
)
C = ResFieldMatrix(
    [
        [1, 2, -3, 5],
        [-1, 0, 5, -7],
        [0, 2, 4, -4],
    ],
    3,
)

res = (A * C.transpose()).to_transversal(std_3_transversal)
pprint(res)

# Uncomment to format a matrix typed on standard input instead:
#print(input_matrix())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment