Last active
December 8, 2021 19:20
-
-
Save zedchance/19d9314a05eb98571749835244105dd5 to your computer and use it in GitHub Desktop.
Optimal parenthesization of matrix multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# optimized parenthesization of matrix multiplication | |
import math | |
def print_2d(arr):
    """Print each row of ``arr`` on its own line, then one trailing blank line.

    Every cell is followed by a single space. A cell equal to ``math.inf`` is
    printed as-is (unpadded). Any other cell is formatted with a minimum width
    of 8 (numbers right-align).
    """
    for line in arr:
        for cell in line:
            shown = cell if cell == math.inf else f'{cell:8}'
            print(shown, end=" ")
        print()  # terminate the current row
    print()  # blank separator after the whole table
def optimal_chain_order(p):
    """Find and print the cheapest parenthesization of a matrix-chain product.

    Bottom-up dynamic program for matrix-chain multiplication. Matrices are
    0-indexed: ``A{i}`` has dimensions ``p[i]`` x ``p[i + 1]``, so a chain of
    ``n`` matrices is described by ``n + 1`` dimensions.

    Prints, in order:
    - a blank line;
    - the cost table ``m``, where cells below the diagonal hold -1;
    - the split table ``s``, where cells that are never filled hold -1;
    - the optimal parenthesization;
    - the minimal number of scalar multiplications.

    Args:
        p: sequence of ``n + 1`` matrix dimensions (``n >= 1``).

    Returns:
        The minimal number of scalar multiplications, ``m[0][n - 1]``.

    Raises:
        ValueError: if ``p`` has fewer than two entries (no matrices).
    """
    n = len(p) - 1
    if n < 1:
        raise ValueError("p must contain at least two dimensions (one matrix)")
    # m[i][j]: minimal cost of computing A{i}..A{j}.
    # s[i][j]: the split point k that achieves that cost.
    # -1 marks cells that are never filled.
    m = [[-1] * n for _ in range(n)]
    s = [[-1] * n for _ in range(n)]
    # base cases: a single matrix needs no multiplications
    for i in range(n):
        m[i][i] = 0
    # calculate sub problems. For each right end j, sweep i downward so that
    # m[i][k] (smaller j) and m[k + 1][j] (larger start, same j) are ready.
    for j in range(n):
        for i in range(j - 1, -1, -1):
            m[i][j] = math.inf
            for k in range(i, j):
                # (A{i}..A{k}) is p[i] x p[k+1] and (A{k+1}..A{j}) is
                # p[k+1] x p[j+1]. Multiplying them costs p[i]*p[k+1]*p[j+1].
                # The original used the 1-indexed p[i-1]*p[k]*p[j], which is
                # wrong here and wraps p[-1] when i == 0.
                q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1]
                if q < m[i][j]:
                    m[i][j] = q
                    s[i][j] = k
    print()
    print_2d(m)
    print_2d(s)

    def _print_opt_parens(i, j):
        # Recursively print the optimal grouping of A{i}..A{j} using s.
        if i == j:
            print(f' A{i} ', end="")
        else:
            print("(", end="")
            _print_opt_parens(i, s[i][j])
            _print_opt_parens(s[i][j] + 1, j)
            print(")", end="")

    _print_opt_parens(0, n - 1)
    cost = m[0][n - 1]
    print(f'\n{cost} multiplications needed')
    return cost
if __name__ == '__main__':
    # Demo chains. Matrix A{i} is p[i] x p[i + 1], for example:
    #   p  -> A0 is 10x100, A1 is 100x5, A2 is 5x50 (product A0 * A1 * A2)
    #   p2 -> a longer chain of seven matrices
    p = [10, 100, 5, 50]
    p2 = [10, 20, 50, 100, 20, 12, 40, 30]
    for dims in (p, p2):
        optimal_chain_order(dims)
Author
zedchance
commented
Nov 17, 2021
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment