Skip to content

Instantly share code, notes, and snippets.

@rejsmont
Last active January 5, 2021 16:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rejsmont/e187dce64375ba75ada7d706ddf89fa8 to your computer and use it in GitHub Desktop.
Save rejsmont/e187dce64375ba75ada7d706ddf89fa8 to your computer and use it in GitHub Desktop.
Create a swarm in native matplotlib out of any scatter plot series
# Based on swarmplot code from seaborn - slightly adapted to take either a PathCollection or a list of PathCollections,
# thus enabling a hybrid swarm of scatter plots using different markers.
# Source: https://github.com/mwaskom/seaborn/blob/f1852b584c3edb750cfc0ee7c6cf6b34453ca5c9/seaborn/categorical.py
# License and copyright: https://github.com/mwaskom/seaborn/blob/master/LICENSE
# Comments from the source removed for brevity
def swarmify(ax, swarm, width, **kws):
    """Spread the points of one or more scatter collections into a beeswarm.

    The offsets of every collection are pooled, sorted by y, converted to
    display coordinates with ``ax.transData`` and shifted along x so that no
    two markers are closer than one marker diameter. The result is converted
    back to data coordinates, clamped to a band of ``width`` centred on the
    x of the first pooled point, and written back to each collection with
    ``set_offsets``.

    Relies on ``np`` (numpy), ``mpl`` (matplotlib) and ``Iterable`` being
    available at module scope.

    Parameters
    ----------
    ax
        Object exposing ``figure.dpi`` and ``transData`` (with ``transform``
        and ``inverted()``); presumably a matplotlib ``Axes``.
    swarm
        A single collection exposing ``get_offsets`` / ``set_offsets``, or
        any iterable of such collections (lists and generators both work).
    width
        Total width, in data units, of the band the swarm may occupy; points
        pushed outside it are clamped onto its edges.
    **kws
        ``linewidth`` / ``lw`` and ``size`` / ``s`` describe the marker used
        to compute the minimum spacing. A missing or ``None`` value falls
        back to ``mpl.rcParams["patch.linewidth"]`` and
        ``mpl.rcParams["lines.markersize"] ** 2`` respectively.

    Returns
    -------
    None
        The collections are modified in place. Nothing happens when
        ``swarm`` is empty or holds no points.

    Raises
    ------
    RuntimeError
        If no non-overlapping position can be found for a point.
    """

    def could_overlap(xy_i, swarm, d):
        """Return the placed points whose y lies within ``d`` of ``xy_i``."""
        _, y_i = xy_i
        neighbors = []
        # ``swarm`` is ordered by increasing y, so scan from the most recent
        # point backwards and stop at the first one that is too far away.
        for xy_j in reversed(swarm):
            _, y_j = xy_j
            if (y_i - y_j) < d:
                neighbors.append(xy_j)
            else:
                break
        # Restore placement order.
        return np.array(neighbors[::-1])

    def position_candidates(xy_i, neighbors, d):
        """List candidate positions for ``xy_i``: its own spot plus one on
        each side of every neighbor, all at the point's own y."""
        candidates = [xy_i]
        _, y_i = xy_i
        left_first = True
        for x_j, y_j in neighbors:
            dy = y_i - y_j
            # Horizontal gap that puts the point at distance d from the
            # neighbor, padded by 5% so the markers do not touch.
            dx = np.sqrt(max(d ** 2 - dy ** 2, 0)) * 1.05
            cl, cr = (x_j - dx, y_i), (x_j + dx, y_i)
            # Alternate which side is listed first from neighbor to neighbor.
            candidates.extend([cl, cr] if left_first else [cr, cl])
            left_first = not left_first
        return np.array(candidates)

    def first_non_overlapping_candidate(candidates, neighbors, d):
        """Return the first candidate at least ``d`` away from all neighbors."""
        if len(neighbors) == 0:
            return candidates[0]
        neighbors_x = neighbors[:, 0]
        neighbors_y = neighbors[:, 1]
        d_square = d ** 2
        for xy_i in candidates:
            x_i, y_i = xy_i
            dx = neighbors_x - x_i
            dy = neighbors_y - y_i
            sq_distances = np.power(dx, 2.0) + np.power(dy, 2.0)
            if np.all(sq_distances >= d_square):
                return xy_i
        raise RuntimeError('No non-overlapping candidates found. '
                           'This should not happen.')

    def beeswarm(orig_xy, d):
        """Place the y-sorted points one by one, preferring positions closest
        to the x of the first point."""
        midline = orig_xy[0, 0]
        swarm = [orig_xy[0]]
        for xy_i in orig_xy[1:]:
            neighbors = could_overlap(xy_i, swarm, d)
            candidates = position_candidates(xy_i, neighbors, d)
            # Try the candidates nearest the midline first.
            offsets = np.abs(candidates[:, 0] - midline)
            candidates = candidates[np.argsort(offsets)]
            swarm.append(
                first_non_overlapping_candidate(candidates, neighbors, d))
        return np.array(swarm)

    def add_gutters(points, center, width):
        """Clamp ``points`` in place to ``center +/- width / 2`` and return it."""
        half_width = width / 2
        return np.clip(points, center - half_width, center + half_width,
                       out=points)

    # Only consult rcParams when the caller did not supply a value.
    lw = kws.get("linewidth", kws.get("lw"))
    if lw is None:
        lw = mpl.rcParams["patch.linewidth"]
    s = kws.get("size", kws.get("s"))
    if s is None:
        s = mpl.rcParams["lines.markersize"] ** 2
    # Marker diameter plus edge on both sides, scaled from points by dpi / 72.
    d = (np.sqrt(s) + lw * 2) * (ax.figure.dpi / 72)

    # Materialise the input: it is traversed twice (read, then write back),
    # which would silently skip the write-back for a one-shot iterator.
    swarm = list(swarm) if isinstance(swarm, Iterable) else [swarm]
    if not swarm:
        return
    offsets = np.concatenate([points.get_offsets() for points in swarm])
    if len(offsets) == 0:
        return

    center = offsets[0, 0]
    order = np.argsort(offsets[:, 1])
    orig_xy = ax.transData.transform(offsets[order])
    new_xy = beeswarm(orig_xy, d)
    new_x, new_y = ax.transData.inverted().transform(new_xy).T
    new_x = add_gutters(new_x, center, width)
    # Undo the y-sort so rows line up with the pooled input order again.
    new_offsets = np.c_[new_x, new_y][np.argsort(order)]

    # Hand each collection back its own slice of the pooled result.
    start = 0
    for points in swarm:
        stop = start + points.get_offsets().shape[0]
        points.set_offsets(new_offsets[start:stop, :])
        start = stop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment