Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save lebedov/5c6662bf4845425a70c09f98e12f85ab to your computer and use it in GitHub Desktop.
Save lebedov/5c6662bf4845425a70c09f98e12f85ab to your computer and use it in GitHub Desktop.
Relabel the nodes of a NetworkX graph that have a given attribute so that their IDs follow the sorted order of some (possibly different) attribute; nodes without the given attribute keep their original IDs.
import networkx as nx
def partly_relabel_by_sorted_attr(g_old, select_attr, select_attr_vals, sort_attr):
    """
    Relabel nodes of a NetworkX graph such that the new IDs are
    sorted according to the order of some attribute, and only the
    nodes that have a selecting attribute are relabeled.

    The IDs already held by the selected nodes are reused: they are sorted
    and reassigned to the selected nodes in order of increasing `sort_attr`
    value. All other nodes keep their original IDs, so the result is
    isomorphic to `g_old`.

    Parameters
    ----------
    g_old : networkx.MultiDiGraph
        NetworkX graph. It is not modified; a relabeled copy is returned.
    select_attr : object
        Attribute to examine to determine which nodes
        to relabel. If a node does not have this attribute, it is
        not relabeled.
    select_attr_vals : list or None
        If a node's `select_attr` value is any of these values, the node is
        relabeled; otherwise it keeps its ID. If None, any value is accepted.
    sort_attr : object
        Attribute on which to sort nodes that are to be relabeled.

    Returns
    -------
    g_new : networkx.MultiDiGraph
        Graph containing partly relabeled nodes.

    Raises
    ------
    AssertionError
        If `g_old` is not a MultiDiGraph or `select_attr_vals` is neither a
        list nor None (these checks are skipped under ``python -O``).
    KeyError
        If a selected node does not have `sort_attr`.
    """
    assert isinstance(g_old, nx.MultiDiGraph)
    assert isinstance(select_attr_vals, list) or select_attr_vals is None

    # Select the nodes that carry `select_attr`, restricted to the accepted
    # values when `select_attr_vals` is given:
    selected = [(n, data) for n, data in g_old.nodes(data=True)
                if select_attr in data and
                (select_attr_vals is None or data[select_attr] in select_attr_vals)]

    # Order the selected nodes by the value of their `sort_attr` attribute
    # (sorted() is stable, so ties keep the graph's iteration order):
    by_attr = sorted(selected, key=lambda item: item[1][sort_attr])

    # Sort the IDs explicitly instead of relying on the graph's node
    # iteration order, which depends on how the graph was built and need not
    # be ascending; otherwise the new IDs would not follow `sort_attr`.
    ids = sorted(n for n, _ in selected)

    # The i-th node in `sort_attr` order receives the i-th smallest ID.
    # relabel_nodes keeps nodes absent from the mapping unchanged, so the
    # unselected nodes need no identity entries.
    mapping = {old: new for (old, _), new in zip(by_attr, ids)}
    return nx.relabel_nodes(g_old, mapping)
# Demo: nodes 4 and 6 have no 'data' attribute and must keep their IDs; the
# IDs of the remaining nodes (0, 1, 2, 3, 5) are reassigned in 'data' order.
g_old = nx.MultiDiGraph()
g_old.add_nodes_from([(0, {'data': 'c'}),
                      (1, {'data': 'b'}),
                      (2, {'data': 'e'}),
                      (3, {'data': 'd'}),
                      (4, dict()),
                      (5, {'data': 'a'}),
                      (6, dict())])
g_old.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 0)])
print('--- old ---')
print(g_old.nodes(data=True))
print(g_old.edges())
g_new = partly_relabel_by_sorted_attr(g_old, 'data', None, 'data')
print('--- new ---')
print(g_new.nodes(data=True))
print(g_new.edges())
print('-----------')
print('old/new isomorphic: %s' % nx.isomorphism.is_isomorphic(g_old, g_new))

# Check the relabeled node IDs and data directly; no hashing is required.
# The node IDs are unique, so sorting the (id, data) pairs never has to
# compare the dicts, and lists of tuples containing dicts support ``==``.
expected = [(0, {'data': 'a'}), (1, {'data': 'b'}), (2, {'data': 'c'}),
            (3, {'data': 'd'}), (4, {}), (5, {'data': 'e'}), (6, {})]
assert sorted(g_new.nodes(data=True)) == expected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment