Skip to content

Instantly share code, notes, and snippets.

@fredyr
Created July 6, 2020 08:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fredyr/d54a2e4731c2e3175b28b47069535da0 to your computer and use it in GitHub Desktop.
Save fredyr/d54a2e4731c2e3175b28b47069535da0 to your computer and use it in GitHub Desktop.
"""Read a CSV distance matrix and render its minimum spanning tree (MST)."""
import sys
import matplotlib.pyplot as plt
import networkx as nx
import pandas
from scipy.sparse.csgraph import minimum_spanning_tree
def read_distmat(f):
    """Load a labelled, comma-separated distance matrix.

    The first CSV column is used as the row index (the item labels). The
    remaining columns hold the pairwise distances.

    Args:
        f: Anything ``pandas.read_csv`` accepts, such as a path or a
            file-like object.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    frame = pandas.read_csv(f, index_col=0, sep=',')
    return frame
def _mst_graph(dist, thres=None):
    """Build the MST of the dense distance array *dist* as a networkx graph.

    Node ``i`` of the returned graph corresponds to row/column ``i`` of
    *dist*. When *thres* is not None, MST edges heavier than *thres* are
    removed.
    """
    sparse_tree = minimum_spanning_tree(dist)
    # networkx 3.0 removed from_scipy_sparse_matrix. from_scipy_sparse_array
    # (networkx >= 2.7) replaces it, so prefer it and fall back on older
    # installations.
    to_graph = (getattr(nx, 'from_scipy_sparse_array', None)
                or nx.from_scipy_sparse_matrix)
    graph = to_graph(sparse_tree)
    # Compare against None explicitly so that thres=0 still prunes.
    if thres is not None:
        heavy = [(u, v) for u, v, w in graph.edges(data='weight') if w > thres]
        graph.remove_edges_from(heavy)
    return graph


def _draw_tree(graph, labels):
    """Draw *graph* with the given node *labels* and weighted edges, then show it."""
    layout = nx.spring_layout(graph)
    nx.draw(graph, pos=layout, node_size=1024, node_color='skyblue')
    nx.draw_networkx_labels(graph, pos=layout, labels=labels, font_size=6)
    # ':g' prints integer-valued weights without a decimal point (as int()
    # did) but keeps fractional distances instead of truncating them.
    edge_labels = {(u, v): '{:g}'.format(w)
                   for u, v, w in graph.edges(data='weight')}
    nx.draw_networkx_edge_labels(graph, pos=layout, edge_labels=edge_labels,
                                 font_size=6)
    plt.show()


def mst(df, thres=None):
    """Render the minimum spanning tree (MST) of a distance matrix.

    Args:
        df: A square DataFrame of pairwise distances, such as the one
            returned by ``read_distmat``. Its column labels, in order, name
            the nodes.
        thres: Optional clustering cut-off. When given (including ``0``),
            MST edges whose weight is strictly greater than ``thres`` are
            removed, which splits the tree into clusters.

    Note:
        scipy's ``minimum_spanning_tree`` treats zero entries of a dense
        matrix as "no edge". Zero distances between distinct items
        therefore never appear as edges.
    """
    # to_numpy() replaces the discouraged DataFrame.values (pandas >= 0.24).
    tree = _mst_graph(df.to_numpy(), thres)
    _draw_tree(tree, labels=dict(enumerate(df.columns)))
if __name__ == '__main__':
    # Usage: python <script> DISTMAT.csv [THRESHOLD]
    # The guard keeps importing this module free of side effects. The
    # argument check replaces a bare IndexError with a usage message.
    if len(sys.argv) < 2:
        sys.exit('usage: {} DISTMAT.csv [THRESHOLD]'.format(sys.argv[0]))
    f = sys.argv[1]
    # An optional second argument exposes mst()'s clustering threshold.
    thres = float(sys.argv[2]) if len(sys.argv) > 2 else None
    mst(read_distmat(f), thres)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment