# Created September 30, 2020 13:17
# GitHub gist mgd020/59cfedb825b8fca685d4912c94396dc7
# NOTE (from the GitHub page): this file may contain bidirectional Unicode text
# that could be interpreted or compiled differently than it appears; review it
# in an editor that reveals hidden Unicode characters.
from fractions import Fraction
from functools import reduce

try:
    from math import gcd  # Python 3.5+
except ImportError:  # Python 2: fractions.gcd (removed from fractions in Python 3.9)
    from fractions import gcd
def matrix_determinant(m):
    """Return the determinant of the square matrix *m*.

    Uses Laplace (cofactor) expansion along the first row, which is exact for
    ``int``/``Fraction`` entries but costs O(n!) time, so it is only suitable
    for small matrices. The empty 0x0 matrix has determinant 1 by convention.
    """
    size = len(m)
    if size == 0:
        return 1
    assert len(m[0]) == size
    if size == 1:
        return m[0][0]
    if size == 2:
        return m[0][0] * m[1][1] - m[0][1] * m[1][0]
    total = 0
    sign = 1
    for col in range(size):
        # Minor: drop row 0 and column `col`.
        minor = [row[:col] + row[col + 1:] for row in m[1:]]
        total += sign * m[0][col] * matrix_determinant(minor)
        sign = -sign  # cofactor signs alternate along the row
    return total
def matrix_select(m, rows, cols):
    """Build the sub-matrix of *m* at the given row and column indices.

    The output rows follow the order of *rows*. Within each row, the entries
    follow the order of *cols*.
    """
    picked = []
    for r in rows:
        picked.append([m[r][c] for c in cols])
    return picked
def matrix_identity(size, one=1, zero=0):
    """Return a ``size`` x ``size`` identity matrix as a list of lists.

    *one* is placed on the diagonal and *zero* everywhere else, so the caller
    chooses the element type (e.g. ``Fraction(1)`` / ``Fraction(0)``). Every
    row is a distinct list.
    """
    # range (not the Python-2-only xrange) so this runs on Python 2 and 3.
    return [[one if row == col else zero for col in range(size)] for row in range(size)]
def matrix_transpose(m):
    """Return the transpose of the rectangular matrix *m* as a new list of lists.

    The number of columns is taken from the first row. An empty matrix
    transposes to ``[]``.
    """
    if not m:
        return []
    rows, cols = len(m), len(m[0])
    return [[m[r][c] for r in range(rows)] for c in range(cols)]
def matrix_sub(m1, m2):
    """Return the element-wise difference ``m1 - m2`` of two equally sized matrices.

    Two empty matrices give ``[]``.
    """
    rows = len(m1)
    assert rows == len(m2)
    if not rows:
        return []
    cols = len(m1[0])
    assert cols == len(m2[0])
    return [[m1[r][c] - m2[r][c] for c in range(cols)] for r in range(rows)]
def matrix_div_scalar(m, scalar):
    """Return a new matrix with every element of *m* divided by *scalar*.

    Uses the ``/`` operator, so the element types decide the result (exact for
    ``Fraction``). The column count is taken from the first row. An empty
    matrix gives ``[]``.
    """
    if not m:
        return []
    cols = len(m[0])
    return [[row[c] / scalar for c in range(cols)] for row in m]
def matrix_invert(m):
    """Return the inverse of the square matrix *m* as a new list of lists.

    1x1 and 2x2 matrices use closed forms. Larger ones use the adjugate:
    ``inverse = transpose(cofactors) / det(m)``. Elements are combined with
    ``/``, so ``Fraction`` entries give an exact inverse. A singular matrix
    makes the final division raise ``ZeroDivisionError``. A 0x0 matrix
    inverts to ``[]``.
    """
    size = len(m)
    if size == 0:
        return []
    assert len(m[0]) == size
    if size == 1:
        return [[1 / m[0][0]]]
    if size == 2:
        return matrix_div_scalar([[m[1][1], -m[0][1]], [-m[1][0], m[0][0]]], matrix_determinant(m))
    indices = range(size)
    # cofactors[row][col] = (-1)**(row + col) * det(m without `row` and `col`)
    cofactors = [
        [
            (1 if (row + col) % 2 == 0 else -1)
            * matrix_determinant(
                matrix_select(m, [r for r in indices if r != row], [c for c in indices if c != col])
            )
            for col in indices
        ]
        for row in indices
    ]
    # Laplace expansion along row 0 reuses the cofactors computed above
    # instead of recomputing the determinant from scratch.
    determinant = sum(m[0][col] * cofactors[0][col] for col in indices)
    return matrix_div_scalar(matrix_transpose(cofactors), determinant)
def matrix_dot_product(m1, m2):
    """Return the matrix product of *m1* and *m2* as a new list of lists.

    *m1* is r x n and *m2* is n x c (the inner dimensions must agree). The
    result is r x c.
    """
    inner = len(m1[0])
    assert inner == len(m2)
    # Transpose m2 once so each output cell is a row of m1 zipped with a column of m2.
    columns = list(zip(*m2))
    return [[sum(a * b for a, b in zip(row, column)) for column in columns] for row in m1]
def lcm(numbers):
    """Return the lowest common multiple of an iterable of integers.

    The result is non-negative. It is 0 if any input is 0, and 1 for an empty
    iterable (the identity element of LCM).
    """
    def lcm2(a, b):
        if a == 0 or b == 0:
            return 0
        # gcd(a, b) divides a exactly, so dividing first keeps intermediates small;
        # abs() normalises the sign regardless of which gcd implementation is in use.
        return abs(a // gcd(a, b) * b)

    return reduce(lcm2, numbers, 1)
def solution(m):
    """Return the absorption probabilities of a Markov chain that starts in state 0.

    *m* is a square matrix of non-negative integer transition weights. Row
    ``i`` holds the relative chances of moving from state ``i`` to each state,
    and a row summing to zero marks a terminal (absorbing) state. The chain is
    assumed to reach a terminal state with probability 1.

    Returns ``[n_0, ..., n_k, d]``. For each terminal state in ascending index
    order, ``n_i`` is the numerator of the probability of ending there. ``d``
    is their common denominator (the LCM of the individual denominators).

    The caller's matrix is not modified.
    """
    terminal_states = []
    non_terminal_states = []
    # Row-normalised copy of m. Terminal rows stay all-zero because only the
    # non-terminal rows (the Q and R blocks of the canonical form) are used below.
    probabilities = []
    for i, row in enumerate(m):
        row_sum = sum(row)
        if row_sum:
            non_terminal_states.append(i)
            probabilities.append([Fraction(a, row_sum) for a in row])
        else:
            terminal_states.append(i)
            probabilities.append([Fraction(0)] * len(row))

    if len(terminal_states) == 1:
        return [1, 1]
    if 0 in terminal_states:
        # State 0 is itself absorbing (and the lowest-indexed terminal state):
        # the chain stays there with probability 1.
        return [1] + [0] * (len(terminal_states) - 1) + [1]

    # In canonical order [terminal, non-terminal] the absorption probabilities
    # are FR = (I - Q)^-1 R, where Q is non-terminal -> non-terminal and R is
    # non-terminal -> terminal.
    q = matrix_select(probabilities, non_terminal_states, non_terminal_states)
    r = matrix_select(probabilities, non_terminal_states, terminal_states)
    identity = matrix_identity(len(non_terminal_states), one=Fraction(1), zero=Fraction(0))
    fr = matrix_dot_product(matrix_invert(matrix_sub(identity, q)), r)

    # State 0 is non-terminal here and has the lowest index, so it is row 0 of FR.
    terminal_probs = fr[0]
    denominator = lcm(f.denominator for f in terminal_probs)
    # Floor division keeps the numerators ints on Python 3 as well as Python 2.
    numerators = [f.numerator * (denominator // f.denominator) for f in terminal_probs]
    return numerators + [denominator]
if __name__ == "__main__":
    # Self-checks for solution(); they run only when this file is executed as a
    # script, not when it is imported.
    assert solution(
        [
            [0, 1, 0, 0, 0, 1],
            [4, 0, 0, 3, 2, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ]
    ) == [0, 3, 2, 9, 14]
    assert solution([[0, 2, 1, 0, 0], [0, 0, 0, 3, 4], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) == [7, 6, 8, 21]
    assert solution([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]]) == [1, 1, 2]
    # Single terminal state: the chain must end there.
    assert solution([[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]) == [1, 1]
    # An unreachable terminal state gets numerator 0.
    assert solution([[1, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) == [0, 1, 1]
    assert solution(
        [
            [0, 1, 0, 0, 0, 1],
            [1, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ]
    ) == [0, 1, 1, 3, 5]
    assert solution([[0, 1, 1], [0, 0, 0], [0, 1, 0]]) == [1, 1]
# (GitHub page footer: sign up or sign in to comment on this gist.)