Skip to content

Instantly share code, notes, and snippets.

@zedchance
Last active December 8, 2021 19:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zedchance/19d9314a05eb98571749835244105dd5 to your computer and use it in GitHub Desktop.
Save zedchance/19d9314a05eb98571749835244105dd5 to your computer and use it in GitHub Desktop.
Optimal parenthesization of matrix multiplication
# optimal parenthesization of a matrix-chain multiplication (dynamic programming)
import math
def print_2d(arr, width=8):
    """Print a 2-D table row by row, right-aligning each entry in ``width`` columns.

    Every entry -- ints and floats, including ``math.inf`` -- is formatted
    with the same field width so the columns line up. Each entry is followed
    by a single space, each row ends with a newline, and one blank line is
    printed after the table.

    Args:
        arr: iterable of rows, each an iterable of formattable values.
        width: minimum field width for each entry (default 8).
    """
    for row in arr:
        # format(math.inf, '8') already yields '     inf', so no special case
        # is needed; the old unpadded branch misaligned the columns.
        print("".join(f'{col:{width}} ' for col in row))
    print()
def optimal_chain_order(p, *, verbose=True):
    """Find the cheapest parenthesization of the chain A0 * A1 * ... * A(n-1).

    Matrix ``Ai`` has shape ``p[i] x p[i + 1]``, so ``p`` holds ``n + 1``
    dimensions for ``n`` matrices. This is the classic O(n^3) dynamic
    program: ``m[i][j]`` is the minimum number of scalar multiplications
    needed to compute ``Ai..Aj``, and ``s[i][j]`` is the split point ``k``
    that achieves it.

    Args:
        p: sequence of matrix dimensions, ``len(p) >= 2``.
        verbose: when true (the default), print the ``m`` and ``s`` tables
            (via ``print_2d``), the parenthesization and the total cost.
            Pass ``False`` to compute silently.

    Returns:
        ``(cost, parens)``: the minimum number of scalar multiplications and
        the optimal parenthesization as a string such as ``"(( A0  A1 ) A2 )"``.

    Raises:
        ValueError: if ``p`` describes no matrices (``len(p) < 2``).
    """
    n = len(p) - 1
    if n < 1:
        raise ValueError("p must contain at least two dimensions (one matrix)")
    m = [[-1] * n for _ in range(n)]
    s = [[-1] * n for _ in range(n)]
    # base cases: a single matrix needs no multiplications
    for i in range(n):
        m[i][i] = 0
    # For each right end j, fill m[i][j] for i = j-1 down to 0 so that both
    # sub-chains m[i][k] (k < j) and m[k+1][j] (k+1 > i) are already known.
    for j in range(n):
        for i in range(j - 1, -1, -1):
            m[i][j] = math.inf
            for k in range(i, j):
                # (Ai..Ak) is p[i] x p[k+1] and (Ak+1..Aj) is p[k+1] x p[j+1];
                # multiplying them costs p[i] * p[k+1] * p[j+1].
                q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1]
                if q < m[i][j]:
                    m[i][j] = q
                    s[i][j] = k

    def _parens(i, j):
        # Rebuild the optimal grouping of Ai..Aj from the split table s.
        if i == j:
            return f' A{i} '
        k = s[i][j]
        return "(" + _parens(i, k) + _parens(k + 1, j) + ")"

    cost = m[0][n - 1]
    parens = _parens(0, n - 1)
    if verbose:
        print()
        print_2d(m)
        print_2d(s)
        print(parens, end="")
        print(f'\n{cost} multiplications needed')
    return cost, parens
if __name__ == '__main__':
    # Example 1: A0 is 10 x 100, A1 is 100 x 5, A2 is 5 x 50;
    # find the cheapest way to evaluate A0 * A1 * A2.
    p = [10, 100, 5, 50]
    # Example 2: a chain of seven matrices, A0 (10 x 20) through A6 (40 x 30).
    p2 = [10, 20, 50, 100, 20, 12, 40, 30]
    for dims in (p, p2):
        optimal_chain_order(dims)
@zedchance
Copy link
Author

       0    50000     7500 
      -1        0     5000 
      -1       -1        0 

      -1        0        0 
      -1       -1        1 
      -1       -1       -1 

( A0 ( A1  A2 ))
7500 multiplications needed

       0     6000    25000    90000    86000    86000    99200 
      -1        0    10000    60000    80000    82400    87200 
      -1       -1        0   100000   120000    96000   105600 
      -1       -1       -1        0   100000    84000   108000 
      -1       -1       -1       -1        0    24000    72000 
      -1       -1       -1       -1       -1        0     9600 
      -1       -1       -1       -1       -1       -1        0 

      -1        0        0        0        0        0        0 
      -1       -1        1        2        3        4        5 
      -1       -1       -1        2        2        2        5 
      -1       -1       -1       -1        3        3        5 
      -1       -1       -1       -1       -1        4        5 
      -1       -1       -1       -1       -1       -1        5 
      -1       -1       -1       -1       -1       -1       -1 

( A0 ((((( A1  A2 ) A3 ) A4 ) A5 ) A6 ))
99200 multiplications needed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment