Created
September 18, 2021 18:46
-
-
Save jkapila/7c103c35e942e7d9eb5a4c195ab9717e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# importing required libraries | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
import networkx as nx | |
# function to plot bayesian network
def plot_bn_graph(G, title='', label_fontsize=14, node_color='b', node_size=10,
                  figsize=(12, 10), position=False, layout='circular', cmap='YlOrRd'):
    """
    Plot a network diagram (e.g. a Bayesian network) using only NetworkX and Matplotlib.

    The figure is displayed with ``plt.show()``; the function returns None.

    :param G: A NetworkX DiGraph whose nodes are variables and whose edges are
        relationships between them. An optional ``weight`` edge attribute is
        used to colour the edges.
    :param title: Title for the plot. An empty title (the default) or None is
        replaced by "Structure Plot".
    :param label_fontsize: Font size of the node labels on the plot.
    :param node_color: Colour of the nodes. Defaults to 'b' (blue).
    :param node_size: Multiplier for node size; every node is drawn with size
        ``3 + node_size * len(G)``. Defaults to 10.
    :param figsize: Figure size of the plot, passed as a tuple ``(w, h)``.
    :param position: If True, node coordinates are read from the ``pos`` node
        attribute and ``layout`` has no effect. Defaults to False.
    :param layout: One of "circular" (default), "spring" or "spectral". Any
        other value falls back to ``nx.layout.planar_layout``.
    :param cmap: Matplotlib colormap (name or instance) used to colour the
        edges by the absolute value of their ``weight``. Edges without a
        weight get the lowest colour of the map.

    Example Usage:
        G = nx.DiGraph()
        nodes = ['Intelligence','Course','Grades','SAT','Job']
        positions = [(4,3),(2,3),(3,2),(5,2),(3,1)]
        for nm,pos in zip(nodes,positions):
            G.add_node(nm,pos=pos)
        G.add_edge('Course','Grades',weight=0.5)
        G.add_edge('Intelligence','Grades',weight=0.5)
        G.add_edge('Intelligence','SAT',weight=1)
        G.add_edge('Grades','Job',weight=1.5)
        plot_bn_graph(G,title='A basic Student Bayesian Network plot',position=True,node_size=200,label_fontsize=16,node_color='skyblue',cmap='Wistia')
    """
    fig, ax = plt.subplots(figsize=figsize)
    cmap = plt.get_cmap(cmap)
    bg_color = '#183861'
    fg_color = 'white'

    # title plus title colour
    ax.set_title(title or 'Structure Plot', color=fg_color, fontsize=16)
    ax.patch.set_facecolor(bg_color)

    pos = _bn_node_positions(G, position, layout)
    # Every node gets the same size, which grows with the number of nodes.
    node_sizes = [3 + node_size * len(G)] * len(G)
    # One colour value per edge, in G.edges() order (the order networkx draws them in).
    edge_colors = _bn_edge_color_values(G)

    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_sizes, node_color=node_color)

    norm = None
    if edge_colors:
        # Pin the colour range explicitly so the edges and the colorbar agree.
        norm = mpl.colors.Normalize(vmin=min(edge_colors), vmax=max(edge_colors))
        # Edges are drawn fully opaque; colour alone encodes the weight.
        nx.draw_networkx_edges(G, pos, ax=ax, node_size=node_sizes, arrowstyle="->",
                               arrowsize=12, edge_color=edge_colors, edge_cmap=cmap,
                               edge_vmin=norm.vmin, edge_vmax=norm.vmax, width=3)

    nx.draw_networkx_labels(G, pos, ax=ax, font_size=label_fontsize, font_color=fg_color,
                            verticalalignment='bottom')

    if norm is not None:
        # A standalone ScalarMappable drives the colorbar. This works whether networkx
        # returned arrow patches (directed) or a LineCollection (undirected).
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        mappable.set_array(edge_colors)
        cb = fig.colorbar(mappable, ax=ax)
        cb.set_label(label='Weights', weight='bold', color=fg_color)
        # Tick marks and tick labels in the foreground colour. tick_params also
        # applies to ticks created later.
        cb.ax.yaxis.set_tick_params(color=fg_color, labelcolor=fg_color)
        cb.outline.set_edgecolor(fg_color)

    ax.set_axis_off()
    fig.set_facecolor(bg_color)
    plt.show()


def _bn_node_positions(G, position, layout):
    """Return the node -> coordinate mapping used by ``plot_bn_graph``.

    When ``position`` is truthy, the ``pos`` node attribute is used. Otherwise
    the named layout ("spring", "circular" or "spectral") is computed. Unknown
    names fall back to ``nx.layout.planar_layout``.
    """
    if position:
        return nx.get_node_attributes(G, 'pos')
    layouts = {'spring': nx.layout.spring_layout,
               'circular': nx.layout.circular_layout,
               'spectral': nx.layout.spectral_layout}
    return layouts.get(layout, nx.layout.planar_layout)(G)


def _bn_edge_color_values(G):
    """Return one colour value per edge of ``G``, in ``G.edges()`` order.

    The value is ``abs(weight)``. Edges lacking a ``weight`` map to 0.0, so
    they take the lowest colour of the colormap. NaN would break the colour
    range. If no edge has a weight, every edge gets the constant 5 and a notice
    is printed.
    """
    weights = [w for _, _, w in G.edges(data="weight")]
    if all(w is None for w in weights):
        print('No weights in edges')
        return [5] * len(weights)
    return [0.0 if w is None else abs(w) for w in weights]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment