Skip to content

Instantly share code, notes, and snippets.

@bagrow
Created December 10, 2015 13:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bagrow/7a99ecf6891099f47e01 to your computer and use it in GitHub Desktop.
Save bagrow/7a99ecf6891099f47e01 to your computer and use it in GitHub Desktop.
Improve the placement of labels on a busy scatterplot
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# annotate_with_graphlayout.py
# Jim Bagrow
# Last Modified: 2015-12-09
"""
Rough improvement of scatter point label placement using
force-directed graph layout.
See also Stack Overflow: http://bit.ly/1NXgB3o
"""
import sys, os
import matplotlib.pyplot as plt
import networkx as nx
from itertools import combinations
import numpy as np
# Tunable layout parameters -- expect to adjust these for other data sets.
# Starting displacement of each label relative to its data point:
dx = 0.05
dy = 0.05
# Margin added on every side of the axes so relocated labels stay visible:
dX = 0.8
dY = 0.8
# Weight of the edge tying a label to its point (limits how far text drifts):
WEIGHT = 10.0
# Passed to the spring layout as its preferred node-to-node distance:
K = 0.5
# Label-label edges get -label_repulsive_weight*WEIGHT, so labels push apart:
label_repulsive_weight = 0.001
# Synthetic demo data: N random points to scatter and label.
N = 30
X = np.random.randn(N)                  # x-coordinates, standard normal draws
Y = 0.5*X + np.random.randn(N)          # y-coordinates: half of x plus fresh normal noise
C = np.random.rand(N)                   # uniform [0, 1) values used for colour and marker size
# build the network:
# every data point i contributes two nodes, "d<i>" (the point itself) and
# "t<i>" (its text label), joined by an edge of weight WEIGHT so the label
# stays tethered to its point.
d_nodes, t_nodes = [], []
G = nx.Graph()
node2coord = {}  # node name -> (x, y); doubles as the initial layout
for i, (x, y) in enumerate(zip(X, Y)):
    d_str = "d%i" % i
    t_str = "t%i" % i
    d_nodes.append(d_str)
    t_nodes.append(t_str)
    G.add_edge(d_str, t_str, weight=WEIGHT)
    node2coord[d_str] = (x, y)
    # labels start at a small fixed offset from their point
    node2coord[t_str] = (x+dx, y+dy)

# "t" nodes are self-repulsive:
# every pair of labels is linked with a negative weight so the layout
# pushes them apart instead of pulling them together.
for ni, nj in combinations(t_nodes, 2):
    G.add_edge(ni, nj, weight=-label_repulsive_weight*WEIGHT)

# compute new layout, only "t" nodes can move:
# the "d" nodes are passed as `fixed`, so they keep their data coordinates.
node2coord_springs = nx.spring_layout(G, k=K, pos=node2coord, fixed=d_nodes)
def draw_and_label(coords, arrows=False):
    """Scatter the data on the current axes and annotate every point.

    Parameters
    ----------
    coords : mapping
        Node name -> (x, y).  For each point index i it must contain
        "d%i" (the annotated position) and "t%i" (where the label text
        is placed).
    arrows : bool, optional
        When true, each annotation is drawn with an arrow using the
        arrow properties defined below; otherwise no arrow is drawn.

    Notes
    -----
    Reads the module-level X, Y, C, N, dX and dY, and draws on whatever
    axes ``plt.gca()`` returns.
    """
    ax = plt.gca()
    # C drives both the colour and (scaled by 200) the marker area
    ax.scatter(X, Y, c=C, s=C*200)

    arrowprops = None
    if arrows:
        arrowprops = dict(facecolor='black', shrink=0.05, width=0.5, headwidth=0.5)

    for i in range(N):
        d_str = "d%i" % i
        t_str = "t%i" % i
        ax.annotate(t_str,
                    xy=coords[d_str],
                    xytext=coords[t_str],
                    arrowprops=arrowprops,
                    )

    # pad the data extent by dX/dY so labels that moved outward remain in view
    ax.set_xlim(X.min()-dX, X.max()+dX)
    ax.set_ylim(Y.min()-dY, Y.max()+dY)
# Draw both layouts in one figure: subplot 121 shows the initial
# fixed-offset labels, subplot 122 the spring-layout labels with arrows.
plt.figure(figsize=(12, 6))
for panel, layout, with_arrows in (
        (121, node2coord, False),
        (122, node2coord_springs, True)):
    plt.subplot(panel)
    draw_and_label(layout, arrows=with_arrows)
plt.show()
@rookiedk
Copy link

rookiedk commented Mar 7, 2019

This is pretty cool! thank-you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment