from typing import Iterable

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
def make_graph(nodes: Iterable, M: np.ndarray, G: "nx.Graph | None" = None):
    """Build a graph from a sequence of nodes and a weight matrix.

    Every node is added with a ``label`` attribute equal to ``f"{node}"``.
    For every pair of positions ``(i, j)`` with ``M[i, j] != 0`` an edge from
    the i-th node to the j-th node is added, carrying the raw matrix value as
    ``weight`` and that value formatted with two decimals as ``label``.

    Args:
        nodes (Iterable): Graph nodes; the i-th node corresponds to row and
            column i of ``M``. Any iterable is accepted (it is materialized once).
        M (np.ndarray): Weight matrix indexed as ``M[origin, destination]``;
            must be at least ``len(nodes) x len(nodes)``.
        G (nx.Graph, optional): Graph instance to populate (e.g. ``nx.DiGraph()``
            or ``nx.Graph()``). Defaults to a new, empty ``nx.DiGraph`` created
            on every call.

    Returns:
        The populated graph object (``G`` itself when one was passed in).

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
    """
    # Build a fresh graph per call. An ``nx.DiGraph()`` default argument would be
    # created once at definition time and shared (and mutated) by every call.
    if G is None:
        G = nx.DiGraph()
    # Materialize so one-shot iterables (e.g. generators) survive all three passes below.
    nodes = list(nodes)
    for node in nodes:
        G.add_node(node, label=f"{node}")
    for i, origin_node in enumerate(nodes):
        for j, destination_node in enumerate(nodes):
            weight = M[i, j]
            # Zero entries mean "no edge".
            if weight != 0:
                G.add_edge(origin_node, destination_node,
                           weight=weight, label=f"{weight:0.02f}")
    return G
def graphplot(G: nx.classes.digraph.DiGraph, M: np.ndarray,
              min_weight_threshold: float = 0.0, bins: int = 4,
              graph_layout: str = "spring_layout",
              figsize: tuple = (20, 10),
              edge_kwargs=None, node_label_kwargs=None, node_kwargs=None,
              cmap=None):
    """Plot a graph with edges coloured by weight, plus a weight colorbar.

    Args:
        G (nx.classes.digraph.DiGraph): Weighted graph. Edge weights are read
            from the ``weight`` edge attribute (edges without one count as 0.0).
            Node labels are read from the ``label`` node attribute.
        M (np.ndarray): Weight matrix. Its entries that differ from
            ``min_weight_threshold`` are binned to choose the colorbar ticks.
        min_weight_threshold (float, optional): Entries of ``M`` equal to this
            value are excluded from the colorbar binning. Defaults to 0.0.
            NOTE(review): despite the name, edges are not filtered by it.
        bins (int, optional): Number of bins to divide the weights into; the
            bin edges become the colorbar ticks. Defaults to 4.
        graph_layout (str, optional): Name of a ``networkx`` layout function.
            Defaults to "spring_layout".
        figsize (tuple, optional): Figure size. Defaults to (20, 10).
        edge_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_edges``.
            When given, they replace the defaults entirely, so the colorbar may
            no longer match the edge colours. Defaults to None.
        node_label_kwargs (dict, optional): Kwargs for
            ``nx.draw_networkx_labels``. Defaults to None.
        node_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_nodes``.
            Defaults to None.
        cmap (str | matplotlib.colors.Colormap, optional): Colormap for the
            edges and colorbar. None uses matplotlib's default colormap.

    Returns:
        matplotlib.axes.Axes: The axes the graph was drawn on.

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
        >>> graphplot(G, M)
    """
    # networkx requires an actual Colormap instance for edge_cmap.
    cmap = plt.get_cmap(cmap)
    # Same iteration order as G.edges(), which networkx uses when no edgelist
    # is given. The default keeps one colour value per edge, so lengths match.
    weights = [w for _, _, w in G.edges(data="weight", default=0.0)]

    node_kwargs = node_kwargs or {"node_color": "k", "node_size": 500}
    node_label_kwargs = node_label_kwargs or {"font_color": "w",
                                              "font_size": 16,
                                              "font_weight": "bold"}

    pos = getattr(nx, graph_layout)(G)
    fig, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx_nodes(G, pos, ax=ax, **node_kwargs)
    nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, "label"),
                            ax=ax, **node_label_kwargs)

    if not weights:
        # No edges: nothing to colour and no weight range for a colorbar.
        return ax

    vmin, vmax = min(weights), max(weights)
    edge_kwargs = edge_kwargs or {"edge_color": weights,
                                  "edge_cmap": cmap,
                                  # Pin the range so it matches the colorbar norm.
                                  "edge_vmin": vmin,
                                  "edge_vmax": vmax,
                                  "width": 4,
                                  "connectionstyle": "arc3, rad=0.2"}
    nx.draw_networkx_edges(G, pos, ax=ax, **edge_kwargs)

    # Colorbar: a standalone ScalarMappable spanning the edge-weight range,
    # with ticks at histogram bin edges of the non-threshold entries of M.
    M = np.asarray(M)
    kept_weights = M[M != min_weight_threshold]
    _, bin_edges = np.histogram(kept_weights, bins=bins)
    mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
                                     cmap=cmap)
    # An explicit ax is required: the mappable is not attached to any axes.
    cbar = fig.colorbar(mappable, ax=ax, ticks=bin_edges)
    cbar.set_label("weights", rotation=270, fontsize=16, labelpad=20)
    return ax
