Skip to content

Instantly share code, notes, and snippets.

@ntcho
Created May 7, 2023 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ntcho/41f846eb21aff8ce8743980cc24bf6c3 to your computer and use it in GitHub Desktop.
Matrix chain multiplication problem in Python
# Matrix Chain Multiplication Optimization
#
# Dimension vector of the chain: matrix A_i (1-based) is p[i - 1] by p[i],
# so A_1 is p[0] x p[1], A_2 is p[1] x p[2], and so on.
p = [10, 20, 12, 17, 22, 19, 50]

# One more than the number of matrices, so the tables below can be indexed
# directly with the 1-based matrix numbers 1..len(p) - 1.
matrix_size = len(p)

# Square tables indexed [a][b]; row and column 0 are unused padding.
#   m     -- minimum scalar-multiplication cost of A_a ... A_b
#   s     -- split index k chosen for A_a ... A_b
#   latex -- LaTeX derivation text produced for entry [a][b]
# Each row is built separately, so rows never alias one another.
m = [[0] * matrix_size for _ in range(matrix_size)]
s = [[0] * matrix_size for _ in range(matrix_size)]
latex = [[""] * matrix_size for _ in range(matrix_size)]
def calc_m_and_s(a, b):
    """Return the minimum scalar-multiplication cost of the chain A_a ... A_b.

    Indices are 1-based: matrix A_i has shape p[i - 1] x p[i].  The recursion
    is memoized through the module-level tables, and for every subchain
    [a, b] (a < b) it solves, this function fills in:

    * ``m[a][b]``     -- the minimum cost of computing A_a ... A_b,
    * ``s[a][b]``     -- the split index k of the optimal
      (A_a ... A_k)(A_{k+1} ... A_b) parenthesization (leftmost on ties),
    * ``latex[a][b]`` -- a LaTeX derivation of how m[a][b] was obtained.

    Args:
        a: index of the first matrix in the subchain (1-based).
        b: index of the last matrix in the subchain (a <= b).

    Returns:
        The minimum number of scalar multiplications (0 for a single matrix).
    """
    # A single matrix needs no multiplication.
    if a == b:
        return 0

    # Already solved?  latex[a][b] is set to a non-empty string for every
    # subchain this function solves, so it marks "computed" reliably.
    # Testing m[a][b] != 0 instead would re-solve subchains whose optimal cost
    # is genuinely 0 (a zero dimension in p), making the recursion exponential.
    if latex[a][b]:
        return m[a][b]

    # Two adjacent matrices: there is exactly one way to multiply them.
    if a + 1 == b:
        cost = p[a - 1] * p[a] * p[b]
        latex[a][b] = f"m[{a},{b}] = p_{a - 1}·p_{a}·p_{b} = {p[a - 1]}·{p[a]}·{p[b]} = {cost}. \\\\"
        m[a][b] = cost
        s[a][b] = a
        return cost

    # General case: try every split point k in [a, b - 1] and keep the
    # cheapest cost of (A_a..A_k)(A_{k+1}..A_b).
    splits = []
    splits_latex_equation = []
    splits_latex_values = []
    for k in range(a, b):
        splits.append(
            calc_m_and_s(a, k)
            + calc_m_and_s(k + 1, b)
            + p[a - 1] * p[k] * p[b]
        )
        splits_latex_equation.append(
            f"m[{a},{k}]+m[{k+1},{b}]+p_{a-1}·p_{k}·p_{b}"
        )
        # m[a][k] and m[k+1][b] were filled in by the recursive calls above
        # (m[a][a] is still its initial 0, the cost of a single matrix).
        splits_latex_values.append(
            f"{m[a][k]}+{m[k+1][b]}+{p[a-1]}·{p[k]}·{p[b]}"
        )

    optimal = min(splits)
    m[a][b] = optimal
    # list.index returns the first minimum, so ties pick the leftmost split.
    s[a][b] = a + splits.index(optimal)

    # Backslashes are doubled everywhere: "\e" and "\h" are invalid Python
    # escape sequences (SyntaxWarning since 3.12).
    latex[a][b] = "\n".join(
        [
            "m[{0},{1}] = min\\left\\{{\\begin{{array}}{{lr}}".format(a, b),
            ", \\\\\n".join(splits_latex_equation),
            "\\end{array}\\right\\}",
            "\\\\[0.5em] \\hspace{31px} ",
            "= min\\left\\{\\begin{array}{lr}",
            ", \\\\\n".join(splits_latex_values),
            "\\end{array}\\right\\}",
            "\\\\[0.25em] \\hspace{31px} ",
            "= min\\{{{0}\\}}".format(", ".join(map(str, splits))),
            "\\\\[0.25em] \\hspace{31px} ",
            "= {0}.".format(m[a][b]),
            "\\\\[0.5em]",
            "s[{0},{1}] = {2}.".format(a, b, s[a][b]),
            "\\\\[1em]",
        ]
    )
    return optimal
# Solve the whole chain A_1 ... A_n (n = len(p) - 1); this fills m, s and latex.
calc_m_and_s(1, len(p) - 1)

# Dump both tables without the unused index-0 row and column: one table row
# per line, cells separated by tabs.
print("m = ")
print("\n".join("\t".join(str(cell) for cell in row[1:]) for row in m[1:]))
print()
print("s = ")
print("\n".join("\t".join(str(cell) for cell in row[1:]) for row in s[1:]))

# Emit the LaTeX derivation of every entry on the `diagonal`-th
# super-diagonal, i.e. latex[i][i + diagonal] for each valid start i.
diagonal = 5
for i in range(1, len(p) - diagonal):
    print(latex[i][i + diagonal])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment