Skip to content

Instantly share code, notes, and snippets.

@hsteinshiromoto
Last active September 3, 2021 04:26
Show Gist options
  • Save hsteinshiromoto/e2e25814104004a4516e65023da5c8e6 to your computer and use it in GitHub Desktop.
Save hsteinshiromoto/e2e25814104004a4516e65023da5c8e6 to your computer and use it in GitHub Desktop.
nx.plotgraph
from collections.abc import Iterable
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
def make_graph(nodes: Iterable, M: "np.ndarray", G: "nx.Graph | None" = None) -> "nx.Graph":
    """Build a graph from a sequence of nodes and a weight matrix.

    Every node in ``nodes`` is added with a ``label`` attribute equal to its
    string form. For every pair of positions ``(i, j)`` with ``M[i, j] != 0``,
    an edge ``nodes[i] -> nodes[j]`` is added. The edge carries the raw
    ``weight`` and a ``label`` holding the weight formatted to two decimals.

    Args:
        nodes (Iterable): Graph nodes, in the same order as the rows/columns
            of ``M``. Any iterable is accepted, including one-shot generators.
        M (np.ndarray): Weight matrix (anything ``np.asarray`` accepts).
            ``M[i, j]`` is the weight of the edge from the i-th to the j-th
            node; zero entries produce no edge.
        G (nx.Graph, optional): Graph to populate in place. Defaults to a new,
            empty ``nx.DiGraph`` created on each call.

    Returns:
        nx.Graph: ``G`` with the nodes and edges added.

    Raises:
        IndexError: If ``M`` has fewer rows or columns than there are nodes.

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
    """
    # A ``nx.DiGraph()`` default value would be built once, at definition time,
    # and shared by every call -- repeated calls would keep growing one graph.
    if G is None:
        G = nx.DiGraph()
    # ``nodes`` is traversed more than once below; materialize it so that
    # generators and other one-shot iterables are not exhausted by the first pass.
    nodes = list(nodes)
    M = np.asarray(M)

    for node in nodes:
        G.add_node(node, label=f"{node}")
    for i, origin_node in enumerate(nodes):
        for j, destination_node in enumerate(nodes):
            weight = M[i, j]
            if weight != 0:
                G.add_edge(origin_node, destination_node, weight=weight,
                           label=f"{weight:0.02f}")
    return G
def graphplot(G: nx.classes.digraph.DiGraph, M: np.ndarray,
              min_weight_threshold: float = 0.0, bins: int = 4,
              graph_layout: str = "spring_layout",
              figsize: tuple = (20, 10),
              cmap=plt.cm.coolwarm,
              edge_kwargs=None, node_label_kwargs=None, node_kwargs=None):
    """Plot a graph with its edges coloured by weight, plus a weight colorbar.

    Args:
        G (nx.classes.digraph.DiGraph): Weighted graph, e.g. one built by
            ``make_graph``. With the default kwargs, edge colours are read from
            the ``weight`` edge attribute and node labels from the ``label``
            node attribute.
        M (np.ndarray): Weight matrix used to set the colour scale.
        min_weight_threshold (float, optional): Entries of ``M`` whose absolute
            value is below this threshold are left out of the colour scale, as
            are all zero entries (absent edges). Defaults to 0.0.
        bins (int, optional): Number of equal-width bins the colour scale is
            divided into; the bin edges are used as colorbar ticks.
            Defaults to 4.
        graph_layout (str, optional): Name of a ``networkx`` layout function,
            called as ``getattr(nx, graph_layout)(G)``.
            Defaults to "spring_layout".
        figsize (tuple, optional): Figure size passed to ``plt.subplots``.
            Defaults to (20, 10).
        cmap (optional): Matplotlib colormap shared by the edges and the
            colorbar. Defaults to plt.cm.coolwarm.
        edge_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_edges``.
            When given, they replace the defaults entirely. Defaults to None.
        node_label_kwargs (dict, optional): Kwargs for
            ``nx.draw_networkx_labels``. When given, they replace the defaults
            entirely. Defaults to None.
        node_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_nodes``.
            When given, they replace the defaults entirely. Defaults to None.

    Returns:
        ax: The matplotlib Axes the graph was drawn on.

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
        >>> graphplot(G, M)

    References:
        [1] https://networkx.org/documentation/stable/auto_examples/drawing/plot_directed.html
    """
    M = np.asarray(M)

    # Colour scale: zero entries mean "no edge", and entries below the threshold
    # in magnitude are excluded as well.
    hidden = (M == 0) | (np.abs(M) < min_weight_threshold)
    _, bin_edges = np.histogram(M[~hidden], bins=bins)
    # One normalization shared by the drawn edges and the colorbar, so the
    # colours on the plot and on the colorbar always agree.
    norm = mpl.colors.Normalize(vmin=bin_edges[0], vmax=bin_edges[-1])

    node_kwargs = node_kwargs or {"node_color": "k", "node_size": 500}
    edge_kwargs = edge_kwargs or {"edge_color": list(nx.get_edge_attributes(G, 'weight').values()),
                                  "edge_cmap": cmap,
                                  "edge_vmin": norm.vmin,
                                  "edge_vmax": norm.vmax,
                                  "width": 4,
                                  "connectionstyle": 'arc3, rad=0.2',
                                  }
    node_label_kwargs = node_label_kwargs or {"font_color": "w", "font_size": 16,
                                              "font_weight": "bold",
                                              }

    pos = getattr(nx, graph_layout)(G)
    fig, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx_nodes(G, pos, ax=ax, **node_kwargs)
    nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, 'label'),
                            ax=ax, **node_label_kwargs)
    nx.draw_networkx_edges(G, pos, ax=ax, **edge_kwargs)

    # A standalone ScalarMappable works whether the edges were drawn as patches
    # (directed) or as a LineCollection (undirected). Passing ``ax`` explicitly
    # is required because the mappable is not attached to any Axes.
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array([])
    cbar = fig.colorbar(mappable, ax=ax, ticks=bin_edges)
    cbar.set_label('weights', rotation=270, fontsize=16, labelpad=20)

    ax.set_axis_off()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment