Skip to content

Instantly share code, notes, and snippets.

@mgd020
Created September 30, 2020 13:17
Show Gist options
  • Save mgd020/59cfedb825b8fca685d4912c94396dc7 to your computer and use it in GitHub Desktop.
Save mgd020/59cfedb825b8fca685d4912c94396dc7 to your computer and use it in GitHub Desktop.
from fractions import Fraction, gcd
def matrix_determinant(m):
    """Return the determinant of the square matrix ``m``.

    Uses Laplace (cofactor) expansion along the first row. This is exact for
    ``int``/``Fraction`` entries but costs O(n!) time, so it is only suitable
    for small matrices.

    Raises:
        AssertionError: if ``m`` is not square.
    """
    size = len(m)
    if size == 0:
        # Empty-product convention: the determinant of a 0 x 0 matrix is 1.
        return 1
    assert len(m[0]) == size
    if size == 1:
        return m[0][0]
    if size == 2:
        return m[0][0] * m[1][1] - m[0][1] * m[1][0]
    total = 0
    sign = 1
    for col in range(size):
        # Minor of entry (0, col): drop row 0 and column ``col``.
        minor = [row[:col] + row[col + 1:] for row in m[1:]]
        total += sign * m[0][col] * matrix_determinant(minor)
        sign = -sign  # cofactor signs alternate +, -, +, ... along the row
    return total
def matrix_select(m, rows, cols):
    """Build the sub-matrix of ``m`` at the given row and column indices.

    Rows and columns appear in the order given by ``rows`` and ``cols``; an
    index listed twice is picked twice.
    """
    picked = []
    for r in rows:
        source_row = m[r]
        picked.append([source_row[c] for c in cols])
    return picked
def matrix_identity(size, one=1, zero=0):
    """Return a ``size`` x ``size`` identity matrix as a list of row lists.

    ``one`` and ``zero`` choose the element values, e.g. ``Fraction(1)`` and
    ``Fraction(0)`` for exact arithmetic. Each row is a distinct list.
    """
    # ``range`` (not the Python-2-only ``xrange``) works on Python 2 and 3.
    identity = []
    for i in range(size):
        row = [zero] * size
        row[i] = one
        identity.append(row)
    return identity
def matrix_transpose(m):
    """Return the transpose of the rectangular matrix ``m`` as new row lists.

    The column count is taken from the first row. An empty matrix transposes
    to ``[]`` instead of raising IndexError.
    """
    if not m:
        return []
    return [[row[col] for row in m] for col in range(len(m[0]))]
def matrix_sub(m1, m2):
    """Return the element-wise difference ``m1 - m2`` as a new matrix.

    Raises:
        AssertionError: if the matrices differ in row count, or any pair of
            corresponding rows differs in length.
    """
    assert len(m1) == len(m2)
    result = []
    for row1, row2 in zip(m1, m2):
        # Check every row, not only the first, so ragged input is caught.
        assert len(row1) == len(row2)
        result.append([a - b for a, b in zip(row1, row2)])
    return result
def matrix_div_scalar(m, scalar):
    """Return a new matrix with every element of ``m`` divided by ``scalar``.

    ``/`` is applied element-wise, so the result type follows the operands:
    exact for ``Fraction``, true division for Python 3 ints/floats (floor
    division for two ints under Python 2). An empty matrix yields ``[]``.

    Raises:
        ZeroDivisionError: if ``scalar`` is zero and ``m`` has any element.
    """
    return [[value / scalar for value in row] for row in m]
def matrix_invert(m):
    """Return the inverse of the square matrix ``m`` via the adjugate formula.

    ``inverse = adj(m) / det(m)``, where ``adj(m)[r][c]`` is the (c, r)
    cofactor of ``m``. Results are exact for ``Fraction`` entries; with plain
    ints the divisions follow the interpreter's ``/`` semantics.

    Raises:
        AssertionError: if ``m`` is not square.
        ZeroDivisionError: if ``m`` is singular (zero determinant).
    """
    size = len(m)
    assert len(m[0]) == size
    if size == 1:
        return [[1 / m[0][0]]]
    det = matrix_determinant(m)
    if det == 0:
        # Fail fast, before the expensive cofactor expansion below.
        raise ZeroDivisionError("matrix is singular")
    if size == 2:
        return matrix_div_scalar([[m[1][1], -m[0][1]], [-m[1][0], m[0][0]]], det)
    # Build the adjugate directly: entry (r, c) is the cofactor of (c, r),
    # i.e. the signed determinant of m without row c and column r.
    adjugate = [
        [
            (-1 if (r + c) % 2 else 1)
            * matrix_determinant(
                matrix_select(m, [k for k in range(size) if k != c], [k for k in range(size) if k != r])
            )
            for c in range(size)
        ]
        for r in range(size)
    ]
    return matrix_div_scalar(adjugate, det)
def matrix_dot_product(m1, m2):
    """Return the matrix product of ``m1`` (r x n) and ``m2`` (n x c).

    This is true matrix multiplication, not an element-wise product; the
    result is an r x c list of row lists.

    Raises:
        AssertionError: if ``m1``'s row width differs from ``m2``'s row count.
    """
    size = len(m1[0])
    assert size == len(m2)
    # Materialise m2's columns once instead of re-indexing m2 for every cell.
    columns = list(zip(*m2))
    return [[sum(a * b for a, b in zip(row, col)) for col in columns] for row in m1]
def lcm(numbers):
    """Return the lowest common multiple of an iterable of positive integers.

    An empty iterable yields 1, the identity element for lcm.
    """
    try:
        from math import gcd as _gcd  # Python 3.5+
    except ImportError:
        from fractions import gcd as _gcd  # Python 2 (removed from fractions in 3.9)
    result = 1
    for n in numbers:
        # Divide before multiplying to keep the intermediate value small.
        result = result // _gcd(result, n) * n
    return result
def solution(m):
    """Return absorption probabilities of an absorbing Markov chain from state 0.

    ``m`` is a square matrix of non-negative integer weights. Each row with a
    non-zero sum is normalised by that sum into transition probabilities. A
    row summing to zero marks a terminal (absorbing) state.

    Returns ``[n_1, ..., n_k, d]``, where ``n_t / d`` is the probability of
    finishing in the t-th terminal state (terminal states in index order) and
    ``d`` is their lowest common denominator. ``m`` itself is not modified.

    Raises:
        ValueError: if ``m`` has no terminal state.
    """
    row_sums = [sum(row) for row in m]
    terminal_states = [i for i, total in enumerate(row_sums) if not total]
    non_terminal_states = [i for i, total in enumerate(row_sums) if total]
    if not terminal_states:
        raise ValueError("the chain has no terminal states")
    if not row_sums[0]:
        # State 0 is itself terminal: the chain is absorbed there immediately.
        return [1] + [0] * (len(terminal_states) - 1) + [1]
    if len(terminal_states) == 1:
        # Every run ends in the only terminal state (assumes it is reachable).
        return [1, 1]
    # With the chain in canonical form [[Q, R], [0, I]], the absorption
    # probabilities are F R, where F = (I - Q)^-1 is the fundamental matrix.
    # Q and R are built from normalised copies, leaving the caller's m intact.
    q = [[Fraction(m[i][j], row_sums[i]) for j in non_terminal_states] for i in non_terminal_states]
    r = [[Fraction(m[i][j], row_sums[i]) for j in terminal_states] for i in non_terminal_states]
    identity = matrix_identity(len(non_terminal_states), one=Fraction(1), zero=Fraction(0))
    fr = matrix_dot_product(matrix_invert(matrix_sub(identity, q)), r)
    # State 0 is non-terminal here and has the lowest index, so it is row 0.
    terminal_probs = fr[0]
    denominator = lcm(f.denominator for f in terminal_probs)
    # Floor division keeps numerators as ints (``/`` gives floats on Python 3).
    numerators = [f.numerator * (denominator // f.denominator) for f in terminal_probs]
    return numerators + [denominator]
def _run_self_tests():
    """Check ``solution`` against known answers; raises AssertionError on mismatch.

    Each distinct case appears once. Copies of the matrices are passed so the
    table is never altered by the call.
    """
    cases = [
        (
            [
                [0, 1, 0, 0, 0, 1],
                [4, 0, 0, 3, 2, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ],
            [0, 3, 2, 9, 14],
        ),
        (
            [[0, 2, 1, 0, 0], [0, 0, 0, 3, 4], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
            [7, 6, 8, 21],
        ),
        (
            [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
            [1, 1, 2],
        ),
        (
            [[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
            [1, 1],
        ),
        (
            [[1, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
            [0, 1, 1],
        ),
        (
            [
                [0, 1, 0, 0, 0, 1],
                [1, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ],
            [0, 1, 1, 3, 5],
        ),
        (
            [[0, 1, 1], [0, 0, 0], [0, 1, 0]],
            [1, 1],
        ),
    ]
    for matrix, expected in cases:
        actual = solution([list(row) for row in matrix])
        assert actual == expected, "solution(%r) returned %r, expected %r" % (matrix, actual, expected)


# Preserve the original behaviour of running the checks at import time.
_run_self_tests()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment