Skip to content

Instantly share code, notes, and snippets.

@dmikushin
Forked from V0XNIHILI/matrix_as_graph.ipynb
Created January 29, 2024 16:48
Show Gist options
  • Save dmikushin/6d9a0742532450c732ef5fb73af1e8f4 to your computer and use it in GitHub Desktop.
Save dmikushin/6d9a0742532450c732ef5fb73af1e8f4 to your computer and use it in GitHub Desktop.
Matrix multiplication in graph form
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Matrix multiplication in graph form\n",
"\n",
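"I was pondering the idea of multiplying matrices in graph form and found [this excellent post](https://www.math3ma.com/blog/matrices-probability-graphs), which explains very clearly how to do it. Based on that explanation, I wrote a program that carries out the multiplication of matrices in graph form automatically.\n",
"\n",
"The idea: an $m \\times n$ matrix $A$ becomes a bipartite graph with an edge $i_r \\to j_c$ of weight $A_{rc}$ for every non-zero entry. To multiply $A$ by $B$, the outputs of $A$'s graph are glued to the inputs of $B$'s graph, and each entry of the product is a sum over paths: $(AB)_{rc} = \\sum_k A_{rk} B_{kc}$, i.e. the sum, over all paths from $i_r$ to $j_c$, of the product of the edge weights along the path."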
]
},
{
"cell_type": "code",
"execution_count": 324,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Matrices to multiply"
]
},
{
"cell_type": "code",
"execution_count": 325,
"metadata": {},
"outputs": [],
"source": [
"# Shapes: (3, 2) @ (2, 2) @ (2, 2) @ (2, 4) -> (3, 4)\n",
"matrix1 = np.array([\n",
"    [-1, 2],\n",
"    [4, -3],\n",
"    [1, 2],\n",
"])\n",
"matrix2 = np.array([\n",
"    [5, 1],\n",
"    [-7, -1],\n",
"])\n",
"matrix3 = np.array([\n",
"    [5, 1],\n",
"    [6, 7],\n",
"])\n",
"matrix4 = np.array([\n",
"    [1, 7, 8, -2],\n",
"    [-7, 5, 6, 3],\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utility functions"
]
},
{
"cell_type": "code",
"execution_count": 327,
"metadata": {},
"outputs": [],
"source": [
"def create_graph_from_matrix(matrix: np.ndarray) -> nx.DiGraph:\n",
"    \"\"\"Build the weighted bipartite digraph of a 2-D matrix.\n",
"\n",
"    Row ``r`` becomes node ``i_{r+1}`` (``layer=0``) and column ``c`` becomes\n",
"    node ``j_{c+1}`` (``layer=1``). Every non-zero entry ``matrix[r, c]``\n",
"    becomes an edge ``i_{r+1} -> j_{c+1}`` with ``weight=matrix[r, c]``; zero\n",
"    entries produce no edge. Array-likes such as nested lists are accepted.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``matrix`` is not two-dimensional.\n",
"    \"\"\"\n",
"    matrix = np.asarray(matrix)\n",
"    if matrix.ndim != 2:\n",
"        raise ValueError(f\"expected a 2-D matrix, got shape {matrix.shape}\")\n",
"\n",
"    G = nx.DiGraph()\n",
"    n_rows, n_cols = matrix.shape\n",
"    G.add_nodes_from((f\"i_{r + 1}\" for r in range(n_rows)), layer=0)\n",
"    G.add_nodes_from((f\"j_{c + 1}\" for c in range(n_cols)), layer=1)\n",
"\n",
"    # np.nonzero yields indices in row-major order, matching np.ndindex.\n",
"    for r, c in zip(*np.nonzero(matrix)):\n",
"        G.add_edge(f\"i_{r + 1}\", f\"j_{c + 1}\", weight=matrix[r, c])\n",
"\n",
"    return G"
]
},
{
"cell_type": "code",
"execution_count": 331,
"metadata": {},
"outputs": [],
"source": [
"def get_max_k(nodes: list):\n",
"    \"\"\"Return the highest intermediate-layer number among ``nodes``, or -1.\n",
"\n",
"    Intermediate nodes are named ``k<layer>_<index>`` (e.g. ``k3_2``); names\n",
"    not starting with ``k`` are ignored.\n",
"    \"\"\"\n",
"    k_layers = (\n",
"        int(name.split('k')[1].split('_')[0])\n",
"        for name in nodes\n",
"        if name.startswith('k')\n",
"    )\n",
"    return max(k_layers, default=-1)\n",
"\n",
"def new_node_name(old_name: str, step: int) -> str:\n",
"    \"\"\"Rename a node of the right-hand graph before it is composed.\n",
"\n",
"    ``i_<n>`` -> ``j_<n>`` (so it merges with the left graph's outputs),\n",
"    ``j_<n>`` -> ``final_i_<n>`` (temporary name that cannot clash), and\n",
"    ``k<l>_<n>`` -> ``k<l + step>_<n>`` (shift intermediate layers right).\n",
"\n",
"    Raises:\n",
"        ValueError: if ``old_name`` is not ``<axis>_<index>`` with an axis of\n",
"            ``i``, ``j`` or ``k<int>``.\n",
"    \"\"\"\n",
"    axis, index = old_name.split(\"_\")\n",
"\n",
"    if axis == \"i\":\n",
"        return f\"j_{index}\"\n",
"    if axis == \"j\":\n",
"        return f\"final_i_{index}\"\n",
"    if axis.startswith(\"k\"):\n",
"        return f\"k{int(axis[1:]) + step}_{index}\"\n",
"\n",
"    # Returning None here would silently relabel the node to None.\n",
"    raise ValueError(f\"unexpected node name: {old_name!r}\")\n",
"\n",
"def get_layer_index(node: str, max_k: int) -> int:\n",
"    \"\"\"Return the multipartite-layout column of ``node``.\n",
"\n",
"    Inputs ``i_*`` are column 0, intermediate ``k<l>_*`` nodes column\n",
"    ``l + 1`` and outputs ``j_*`` the last column, ``max_k + 2``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``node`` does not start with ``i``, ``k`` or ``j``.\n",
"    \"\"\"\n",
"    if node.startswith('i'):\n",
"        return 0\n",
"    if node.startswith('k'):\n",
"        return int(node.split('k')[1].split('_')[0]) + 1\n",
"    if node.startswith('j'):\n",
"        return max_k + 2\n",
"    # Returning None would give the node layer=None in the plot layout.\n",
"    raise ValueError(f\"unexpected node name: {node!r}\")\n",
"\n",
"def multiply_graphs(G1: nx.DiGraph, G2: nx.DiGraph) -> nx.DiGraph:\n",
"    \"\"\"Chain the graphs of two matrices into the graph of their product.\n",
"\n",
"    The outputs ``j_*`` of ``G1`` and the inputs ``i_*`` of ``G2`` are merged\n",
"    into a new intermediate layer ``k<m>_*`` (``m`` = one past the highest\n",
"    existing ``k`` layer of ``G1``), ``G2``'s own ``k`` layers are shifted to\n",
"    come after it, and ``G2``'s outputs become the outputs ``j_*`` of the\n",
"    result. Every node gets a ``layer`` attribute for\n",
"    ``nx.multipartite_layout``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``G1`` has a different number of outputs than ``G2``\n",
"            has inputs (the underlying matrices are not conformable).\n",
"    \"\"\"\n",
"    n_left_outputs = sum(1 for node in G1.nodes if node.startswith(\"j_\"))\n",
"    n_right_inputs = sum(1 for node in G2.nodes if node.startswith(\"i_\"))\n",
"    if n_left_outputs != n_right_inputs:\n",
"        raise ValueError(\n",
"            f\"cannot multiply: left graph has {n_left_outputs} outputs \"\n",
"            f\"but right graph has {n_right_inputs} inputs\"\n",
"        )\n",
"\n",
"    max_index = get_max_k(G1.nodes)\n",
"    new_layer = max_index + 1\n",
"\n",
"    # G2's inputs become j_* so nx.compose merges them with G1's outputs;\n",
"    # G2's outputs get a temporary final_i_* name that cannot clash.\n",
"    G_2_relabeled = nx.relabel_nodes(G2, {node: new_node_name(node, max_index + 2) for node in G2.nodes})\n",
"    G_composed = nx.compose(G1, G_2_relabeled)\n",
"\n",
"    def final_name(node: str) -> str:\n",
"        # Shared j_* nodes become the new intermediate layer and the\n",
"        # temporary final_i_* nodes become the product's outputs. Matching\n",
"        # on the prefix (instead of str.replace) leaves other names intact.\n",
"        if node.startswith(\"j_\"):\n",
"            return f\"k{new_layer}_{node[len('j_'):]}\"\n",
"        if node.startswith(\"final_i_\"):\n",
"            return f\"j_{node[len('final_i_'):]}\"\n",
"        return node\n",
"\n",
"    G_composed_relabeled = nx.relabel_nodes(G_composed, {node: final_name(node) for node in G_composed.nodes})\n",
"\n",
"    max_relabeled_index = get_max_k(G_composed_relabeled.nodes)\n",
"    attrs = {node: {\"layer\": get_layer_index(node, max_relabeled_index)} for node in G_composed_relabeled.nodes}\n",
"    nx.set_node_attributes(G_composed_relabeled, attrs)\n",
"\n",
"    return G_composed_relabeled"
]
},
{
"cell_type": "code",
"execution_count": 332,
"metadata": {},
"outputs": [],
"source": [
"def plot_matrix(G: nx.DiGraph):\n",
"    \"\"\"Draw ``G`` in columns given by each node's ``layer`` attribute.\n",
"\n",
"    Nodes are labelled with their names and edges with their ``weight``.\n",
"    \"\"\"\n",
"    layout = nx.multipartite_layout(G, subset_key='layer')\n",
"    nx.draw(G, pos=layout, with_labels=True)\n",
"    nx.draw_networkx_edge_labels(\n",
"        G,\n",
"        layout,\n",
"        nx.get_edge_attributes(G, \"weight\"),\n",
"        label_pos=0.6,\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph multiplication"
]
},
{
"cell_type": "code",
"execution_count": 333,
"metadata": {},
"outputs": [],
"source": [
"# One graph per matrix, in the same order as matrix1..matrix4.\n",
"G1, G2, G3, G4 = (\n",
"    create_graph_from_matrix(m) for m in (matrix1, matrix2, matrix3, matrix4)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 341,
"metadata": {},
"outputs": [],
"source": [
"# Left fold: ((G1 * G2) * G3) * G4, mirroring ((m1 @ m2) @ m3) @ m4.\n",
"G = G1\n",
"for G_right in (G2, G3, G4):\n",
"    G = multiply_graphs(G, G_right)"
]
},
{
"cell_type": "code",
"execution_count": 342,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Columns from left to right: inputs i_*, intermediate layers k0_*, k1_*, ...,\n",
"# then outputs j_*.\n",
"plot_matrix(G=G)"
]
},
{
"cell_type": "code",
"execution_count": 339,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 167, -991, -1144, 106],\n",
" [ -383, 2179, 2516, -224],\n",
" [ 61, -437, -504, 54]])"
]
},
"execution_count": 339,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def graph_to_matrix(G: nx.DiGraph) -> np.ndarray:\n",
"    \"\"\"Read the matrix represented by a (product) graph back out of it.\n",
"\n",
"    Entry ``(r, c)`` is the sum, over every path from ``i_{r+1}`` to\n",
"    ``j_{c+1}``, of the product of the edge weights along that path.\n",
"    \"\"\"\n",
"    def by_index(prefix: str) -> list:\n",
"        names = [node for node in G.nodes if node.startswith(prefix)]\n",
"        return sorted(names, key=lambda name: int(name.split(\"_\")[1]))\n",
"\n",
"    inputs, outputs = by_index(\"i_\"), by_index(\"j_\")\n",
"    order = list(nx.topological_sort(G))\n",
"    rows = []\n",
"    for source in inputs:\n",
"        # path_sum[node]: summed weight-products of all paths source -> node.\n",
"        # Visiting nodes in topological order finalises each sum before it\n",
"        # is propagated to the node's successors.\n",
"        path_sum = {source: 1}\n",
"        for node in order:\n",
"            if node not in path_sum:\n",
"                continue\n",
"            for _, succ, weight in G.out_edges(node, data=\"weight\"):\n",
"                path_sum[succ] = path_sum.get(succ, 0) + path_sum[node] * weight\n",
"        rows.append([path_sum.get(target, 0) for target in outputs])\n",
"    return np.array(rows)\n",
"\n",
"\n",
"expected = ((matrix1 @ matrix2) @ matrix3) @ matrix4\n",
"# The graph built above must encode exactly the numeric product.\n",
"np.testing.assert_array_equal(graph_to_matrix(G), expected)\n",
"expected"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment