Skip to content

Instantly share code, notes, and snippets.

@PCJohn
Created January 14, 2019 07:28
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PCJohn/ed503fad5c8b43e3da29ffc60b1313d3 to your computer and use it in GitHub Desktop.
Save PCJohn/ed503fad5c8b43e3da29ffc60b1313d3 to your computer and use it in GitHub Desktop.
Display graphs with networkx
from __future__ import division
import numpy as np
from matplotlib import pyplot as plt
import networkx as nx
import copy
class Graph:
    """Labelled graph built from an adjacency matrix, with drawing helpers.

    Attributes set by the constructor:
      adj         -- the adjacency matrix exactly as passed in
      char_list   -- the node labels exactly as passed in
      lab         -- dict mapping node index -> label
      edge_labels -- dict mapping (row, col) -> adj[row, col] for every
                     strictly positive entry of adj
      G           -- the networkx graph built from adj
    """

    # Accepted ``cent_type`` values mapped to the name of the networkx
    # function that computes that centrality. Names (not function objects)
    # are stored so the lookup on ``nx`` happens only when actually needed.
    _CENTRALITY_FUNCS = {
        'betweenness': 'betweenness_centrality',
        'closeness': 'closeness_centrality',
        'eigenvector': 'eigenvector_centrality',
        'harmonic': 'harmonic_centrality',
        'katz': 'katz_centrality',
        'load': 'load_centrality',
        'degree': 'degree_centrality',
    }

    def __init__(self, adj, labels):
        """Build the graph.

        adj    -- square matrix indexable as adj[i, j]; entries > 0 are
                  recorded as labelled edges.
        labels -- one label per node, in node-index order.
        """
        self.adj = adj
        self.char_list = labels
        self.lab = dict(enumerate(labels))
        n_nodes = len(self.lab)
        self.edge_labels = {
            (n1, n2): adj[n1, n2]
            for n1 in range(n_nodes)
            for n2 in range(n_nodes)
            if adj[n1, n2] > 0
        }
        # networkx 3.0 removed from_numpy_matrix; prefer from_numpy_array and
        # only fall back to the old name on releases that lack it.
        build_graph = getattr(nx, 'from_numpy_array', None)
        if build_graph is None:
            build_graph = nx.from_numpy_matrix
        self.G = build_graph(adj)

    def most_central(self, F=1, cent_type='betweenness'):
        """Rank nodes by centrality and keep those at or above a cutoff.

        F         -- multiplier applied to the mean centrality to obtain the
                     cutoff.
        cent_type -- one of 'betweenness', 'closeness', 'eigenvector',
                     'harmonic', 'katz', 'load', 'degree'.

        Returns a tuple (cent_dict, thresh, g):
          cent_dict -- dict mapping node label -> centrality value
          thresh    -- the cutoff, F * mean centrality (0 for an empty graph)
          g         -- a new Graph rebuilt from a copy of self.adj with every
                       node whose centrality is below thresh removed

        Raises ValueError if cent_type is not one of the supported names.
        """
        if cent_type not in self._CENTRALITY_FUNCS:
            raise ValueError(
                'unknown cent_type %r; expected one of %s'
                % (cent_type, sorted(self._CENTRALITY_FUNCS)))
        centrality = getattr(nx, self._CENTRALITY_FUNCS[cent_type])
        # Materialise once: the pairs are iterated several times below.
        ranking = list(centrality(self.G).items())
        cent_dict = dict((self.lab[n], r) for n, r in ranking)

        m_centrality = sum(r for _, r in ranking)
        if ranking:
            m_centrality = m_centrality / len(ranking)
        thresh = F * m_centrality

        # Create a graph with only the nodes at or above the cutoff.
        g = Graph(self.adj.copy(), self.char_list)
        for n, r in ranking:
            if r < thresh:
                g.G.remove_node(n)
                del g.lab[n]
        # Drop edge labels that touch a removed node so g can still be drawn.
        g.edge_labels = {
            edge: weight
            for edge, weight in g.edge_labels.items()
            if edge[0] in g.lab and edge[1] in g.lab
        }
        return (cent_dict, thresh, g)

    def show(self, path=None):
        """Draw the graph; display it if path is None, else save it to path."""
        pos = nx.spring_layout(self.G)
        nx.draw_networkx_nodes(self.G, pos, alpha=0.5)
        nx.draw_networkx_labels(self.G, pos, self.lab, alpha=0.5)
        nx.draw_networkx_edges(self.G, pos, alpha=0.5)
        # Only label edges that are still present in G; a label whose endpoint
        # was removed has no layout position to be drawn at.
        drawable = {
            edge: weight
            for edge, weight in self.edge_labels.items()
            if self.G.has_edge(*edge)
        }
        nx.draw_networkx_edge_labels(self.G, pos, edge_labels=drawable, alpha=1)
        if path is None:
            plt.show()
        else:
            plt.savefig(path)
if __name__ == '__main__':
    # Small demo: build a three-node graph and render it to an image file.
    node_labels = ['v1', 'v2', 'v3']
    adj = np.array([
        [0, 1, 1],
        [1, 0, 0],
        [0, 0, 1],
    ])
    g = Graph(adj, node_labels)
    # Pass path=None instead to open an interactive window rather than save.
    g.show(path='mygraph.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment