Skip to content

Instantly share code, notes, and snippets.

@jkapila
Created September 18, 2021 18:46
Show Gist options
  • Save jkapila/7c103c35e942e7d9eb5a4c195ab9717e to your computer and use it in GitHub Desktop.
Save jkapila/7c103c35e942e7d9eb5a4c195ab9717e to your computer and use it in GitHub Desktop.
# importing required libraries
import matplotlib.pyplot as plt
import matplotlib as mpl
import networkx as nx
# function to plot bayesian network
def plot_bn_graph(G, title='', label_fontsize=14, node_color='b', node_size=10,
                  figsize=(12, 10), position=False, layout='circular',
                  cmap='YlOrRd'):
    """
    Plot a network diagram using Networkx and Matplotlib only.

    The figure is drawn on a dark background, edges are coloured by the
    absolute value of their ``weight`` attribute, a colorbar labelled
    "Weights" is attached, and the figure is shown with ``plt.show()``.
    Nothing is returned.

    :param G: A Networkx DiGraph having nodes as variables and edges as
        relationships. Edge artists are handled one by one, which assumes
        the per-edge arrow patches returned for directed graphs.
    :param title: Title for the plot. An empty title falls back to
        "Structure Plot".
    :param label_fontsize: Font size of the variable labels on the plot.
    :param node_color: Color of the nodes. Defaults to 'b' (blue).
    :param node_size: Multiplier for the size of the nodes. Defaults to 10.
        Every node is drawn with size ``3 + node_size * len(G)``.
    :param figsize: Figure size of the plot, passed as a tuple (w, h).
    :param position: If True, the 'pos' node attribute of the graph is used
        as the coordinate of each node and 'layout' has no impact.
        Defaults to False.
    :param layout: One of "circular", "spring" or "spectral". Defaults to
        "circular". Any other value falls back to the planar layout.
    :param cmap: Color gradient used to draw the edges, based on weights.
        Can take any matplotlib colormap name.

    Example Usage:
        G = nx.DiGraph()
        nodes = ['Intelligence','Course','Grades','SAT','Job']
        positions = [(4,3),(2,3),(3,2),(5,2),(3,1)]
        for nm,pos in zip(nodes,positions):
            G.add_node(nm,pos=pos)
        G.add_edge('Course','Grades',weight=0.5)
        G.add_edge('Intelligence','Grades',weight=0.5)
        G.add_edge('Intelligence','SAT',weight=1)
        G.add_edge('Grades','Job',weight=1.5)
        plot_bn_graph(G,title='A basic Student Bayesian Network plot',position=True,node_size=200,label_fontsize=16,node_color='skyblue',cmap='Wistia')
    """
    fig, ax = plt.subplots(figsize=figsize)
    cmap = plt.get_cmap(cmap)
    bg_color = '#183861'
    fg_color = 'white'

    # Title (with fallback) and dark background for the axes.
    ax.set_title(title or 'Structure Plot', color=fg_color, fontsize=16)
    ax.patch.set_facecolor(bg_color)

    # Node coordinates: explicit 'pos' attributes win over any named layout.
    if position:
        pos = nx.get_node_attributes(G, 'pos')
    else:
        layouts = {
            'spring': nx.layout.spring_layout,
            'circular': nx.layout.circular_layout,
            'spectral': nx.layout.spectral_layout,
        }
        pos = layouts.get(layout, nx.layout.planar_layout)(G)

    # All nodes share one size that grows with the number of nodes.
    node_sizes = [3 + node_size * len(G)] * len(G)

    # Edge colours come from |weight|. The colour range is computed from the
    # known weights only, so edges without a weight cannot distort it.
    weights = [w for _, _, w in G.edges(data='weight')]
    known = [abs(w) for w in weights if w is not None]
    if not known:
        # No edge carries a weight: use one constant value for every edge.
        edge_colors = [5] * len(weights)
        vmin = vmax = 5
        print('No weights in edges')
    else:
        # NaN marks the edges that have no weight; float('nan') avoids
        # depending on numpy, which this file does not import.
        nan = float('nan')
        edge_colors = [nan if w is None else abs(w) for w in weights]
        vmin, vmax = min(known), max(known)

    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_sizes,
                           node_color=node_color)
    edges = nx.draw_networkx_edges(G, pos, ax=ax, node_size=node_sizes,
                                   arrowstyle="->", arrowsize=12,
                                   edge_color=edge_colors, edge_cmap=cmap,
                                   edge_vmin=vmin, edge_vmax=vmax, width=3)
    nx.draw_networkx_labels(G, pos, ax=ax, font_size=label_fontsize,
                            font_color=fg_color, verticalalignment='bottom')

    # Every edge is drawn fully opaque.
    for arrow in edges:
        arrow.set_alpha(1.0)

    # Colorbar built from the edge patches, using the same range as the edges.
    pc = mpl.collections.PatchCollection(edges, cmap=cmap)
    pc.set_array(edge_colors)
    pc.set_clim(vmin, vmax)
    cb = plt.colorbar(pc, ax=ax)
    # Style the colorbar label, ticks, outline and tick labels to match the
    # foreground colour.
    cb.set_label(label='Weights', weight='bold', color=fg_color)
    cb.ax.yaxis.set_tick_params(color=fg_color)
    cb.outline.set_edgecolor(fg_color)
    plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color=fg_color)

    # Hide the axis frame of the graph and paint the whole figure background.
    ax.set_axis_off()
    fig.set_facecolor(bg_color)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment