Skip to content

Instantly share code, notes, and snippets.

@PirosB3
Created May 3, 2015 11:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
Save PirosB3/0aac2b7204023a1b4d7f to your computer and use it in GitHub Desktop.
def matrix_mul(m1, m2):
    """Return the scalar-multiplication cost and the shape of ``m1 * m2``.

    Only shapes are handled; no matrix data is involved.

    Args:
        m1: ``(rows, cols)`` shape of the left operand.
        m2: ``(rows, cols)`` shape of the right operand.

    Returns:
        ``(cost, shape)``. ``cost`` is ``rows(m1) * cols(m1) * cols(m2)``, the
        number of scalar multiplications of the schoolbook product. ``shape``
        is the ``(rows, cols)`` shape of the product.

    Raises:
        ValueError: if ``m1``'s column count differs from ``m2``'s row count,
            i.e. the product is undefined.
    """
    rows, inner = m1
    inner_right, cols = m2
    # Previously the right operand's row count was unpacked but never checked,
    # so incompatible shapes silently produced a meaningless cost.
    if inner != inner_right:
        raise ValueError(
            "incompatible shapes %r and %r: inner dimensions differ" % (m1, m2))
    return rows * inner * cols, (rows, cols)
def matrix_chain_mult(matrices):
    """Return the minimum multiplication cost and result shape of a matrix chain.

    Finds the cheapest parenthesization of ``matrices[0] * matrices[1] * ...``
    with the classic O(n^3) dynamic program. The previous implementation
    re-solved overlapping sub-chains recursively, which takes exponential time.
    The cost of multiplying two shapes follows the same model as
    ``matrix_mul``: ``rows * inner * cols``.

    Args:
        matrices: sequence of ``(rows, cols)`` shapes, in multiplication order.

    Returns:
        ``(cost, shape)``:

        * ``(0, None)`` for an empty chain;
        * ``(0, matrices[0])`` for a single matrix;
        * ``(float('inf'), None)`` when some adjacent pair has mismatched
          inner dimensions, so the product is undefined;
        * otherwise the minimum number of scalar multiplications and the
          ``(rows, cols)`` shape of the full product.
    """
    n = len(matrices)
    if n == 0:
        return 0, None
    if n == 1:
        return 0, matrices[0]

    # Parenthesization cannot repair a mismatched adjacent pair: the product
    # is defined only if every neighbour shares its inner dimension. Checking
    # upfront also avoids the old crash, where an invalid left sub-chain
    # yielded a None size that was then indexed.
    for left, right in zip(matrices, matrices[1:]):
        if left[1] != right[0]:
            return float('inf'), None

    # matrices[i] has shape (dims[i], dims[i + 1]).
    dims = [m[0] for m in matrices] + [matrices[-1][1]]

    # cost[i][j] is the cheapest way to multiply matrices[i..j] inclusive;
    # single matrices (i == j) cost nothing.
    cost = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            # Split into matrices[i..k] (dims[i] x dims[k+1]) and
            # matrices[k+1..j] (dims[k+1] x dims[j+1]), then combine the two.
            cost[i][j] = min(
                cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                for k in range(i, j)
            )
    return cost[0][n - 1], (dims[0], dims[n])
# Demo chain of (rows, cols) shapes.
matrices = [(10, 100), (100, 5), (5, 50), (50, 1)]

# Print only when run as a script, not on import. The single-argument
# print(...) form works under both Python 2 and Python 3.
if __name__ == "__main__":
    print(matrix_chain_mult(matrices))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.