Last active
February 11, 2021 16:09
-
-
Save asitang/9fdf8222930390a27285ffced8f20330 to your computer and use it in GitHub Desktop.
A minimalistic topological layout for DAGs using NetworkX and Plotly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx | |
import plotly.graph_objects as go | |
def _topological_layout(nodes_topo_sorted):
    """Map each node to (x, 0), with x = 1, 2, ... following the given order."""
    return {node: (position, 0)
            for position, node in enumerate(nodes_topo_sorted, start=1)}


def _hover_template(description):
    """Build a Plotly hover template listing every key of *description*.

    Each entry reads ``key: %{customdata.key}``. Plotly's label markup treats
    ``<br>`` as a line break rather than as a container tag, so entries are
    simply joined with ``<br>`` (no ``</br>`` closers). An empty description
    yields an empty template.
    """
    return "<br>".join("{0}: %{{customdata.{0}}}".format(key)
                       for key in description)


def _edge_curve_paths(nx_dag, nodes_topo_sorted, node_to_coor, height_param,
                      max_height=2, min_height=1):
    """Return one SVG quadratic-Bezier path string per edge of *nx_dag*.

    Sources are visited in topological order, and each source's targets are
    visited in topological order as well. The control point sits halfway
    between the two endpoints at a height proportional to their horizontal
    distance. Consecutive edges alternate between positive and negative
    heights.
    """
    topo_rank = {node: rank for rank, node in enumerate(nodes_topo_sorted)}
    paths = []
    side = 1
    for source in nodes_topo_sorted:
        targets = sorted((target for _, target in nx_dag.out_edges(source)),
                         key=lambda node: topo_rank[node])
        x0, y0 = node_to_coor[source]
        for target in targets:
            x1, y1 = node_to_coor[target]
            # Evaluate as (span * delta) / height_param; reordering the
            # arithmetic would change float rounding in the emitted paths.
            height = (x1 - x0) * float(max_height - min_height) / height_param
            if side < 0:
                height = -height
            control_x = x0 + float(x1 - x0) / 2
            paths.append("M {},{} Q {},{} {},{}".format(
                x0, y0, control_x, height, x1, y1))
            side = -side
    return paths


def plotly_DAG(nx_dag, node_description_dict=None, *, height_param=10,
               verbose=False):
    """Draw a DAG with its nodes on a line in topological order.

    Nodes are placed at x = 1, 2, ... (y = 0) in the order returned by
    ``nx.topological_sort``. Every edge is drawn as a quadratic Bezier arc
    whose height grows with the horizontal distance it spans. Successive
    edges alternate above and below the axis.

    Args:
        nx_dag: a networkX directed acyclic graph.
        node_description_dict: optional nested dictionary of node attributes
            to show on hover, e.g.::

                {
                    'node_A': {'property_1': 'value_A_1', 'property_2': 'value_A_2'},
                    'node_B': {'property_1': 'value_B_1', 'property_2': 'value_B_2'},
                }

            Nodes missing from the dictionary (or all nodes, when this is
            None) get an empty hover template.
        height_param: divisor controlling arc height; larger values give
            flatter arcs. Tune it to the number of nodes being plotted.
        verbose: if True, print the topological order for debugging.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    if node_description_dict is None:
        node_description_dict = {}

    nodes_topo_sorted = list(nx.topological_sort(nx_dag))
    if verbose:
        print('nodes_topo_sorted:', nodes_topo_sorted)
        print('nodes_to_topo_order:',
              {node: order for order, node in enumerate(nodes_topo_sorted)})

    node_to_layout_coor = _topological_layout(nodes_topo_sorted)
    layout = [node_to_layout_coor[node] for node in nodes_topo_sorted]

    nodes_descriptions = [node_description_dict.get(node, {})
                          for node in nodes_topo_sorted]
    hovertemplate = [_hover_template(description)
                     for description in nodes_descriptions]

    edge_curves = _edge_curve_paths(nx_dag, nodes_topo_sorted,
                                    node_to_layout_coor, height_param)

    fig = go.Figure()
    # Nodes: one scatter trace; hover text pulls values from customdata.
    fig.add_trace(go.Scatter(
        x=[x for x, _ in layout],
        y=[y for _, y in layout],
        customdata=nodes_descriptions,
        hovertemplate=hovertemplate,
        mode='markers+text',
        text=nodes_topo_sorted,
        # Alternate labels above and below the axis so adjacent node names
        # are less likely to collide.
        textposition=['top center' if i % 2 == 0 else 'bottom center'
                      for i in range(len(nodes_topo_sorted))],
        name="",
        marker=dict(
            size=[20] * len(layout),
            color=[1] * len(layout),
        ),
    ))
    # Edges are drawn as layout shapes (SVG paths) rather than as traces.
    fig.update_layout(
        shapes=[dict(type="path", path=path, line_color="rgba(135,206,250,0.6)")
                for path in edge_curves],
        plot_bgcolor='rgba(0,0,0,0)',
    )
    # Hide grid lines, zero lines and tick labels on both axes.
    axis_style = dict(showgrid=False, zeroline=False, visible=False)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    fig.update_traces(textfont_size=8)
    return fig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
networkx==2.4
plotly==4.9.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import combinations

# Create a fully connected DAG with 5 nodes: an edge n_i -> n_j for every i < j.
a_dag = nx.DiGraph()
node_num = 5
for i, j in combinations(range(node_num), 2):
    a_dag.add_edge('n_' + str(i), 'n_' + str(j))
print('Edges:', a_dag.edges)
# Plot using the above function. A bare expression only renders when it is
# the last value of a notebook cell; show() also displays the figure when
# this runs as a plain script.
fig = plotly_DAG(a_dag, {})
fig.show()
Author
asitang
commented
Feb 10, 2021
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment