Skip to content

Instantly share code, notes, and snippets.

@camriddell
Last active May 24, 2024 15:06
Show Gist options
  • Save camriddell/cf787bc25296caf8bfd83bb7c915cfc8 to your computer and use it in GitHub Desktop.
Save camriddell/cf787bc25296caf8bfd83bb7c915cfc8 to your computer and use it in GitHub Desktop.
Create a bump chart in Python using matplotlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
def sigmoid(xs, ys, smooth=8, n=100):
    """Interpolate a sigmoid ("S"-shaped) curve between start and stop points.

    Parameters
    ----------
    xs, ys : array_like
        Two-element sequences ``(start, stop)``. Each element may be a scalar
        or a length-N array, so an input of shape (2, N) interpolates N
        independent segments at once: ``xs[0]`` / ``ys[0]`` hold the starting
        coordinates and ``xs[1]`` / ``ys[1]`` the stopping coordinates.
    smooth : int or float
        The logistic function is sampled on ``[-smooth, smooth]``; larger
        values give a steeper transition and bring the ends of the curve
        closer to ``ys[0]`` / ``ys[1]``. Doesn't look great for values less
        than 8.
    n : int
        Number of points to interpolate to (per segment).

    Returns
    -------
    tuple (xs, ys)
        The smoothed & interpolated x/y values. Each array has shape (n, N)
        (``(n, 1)`` for scalar endpoints): one column per segment.
    """
    (x_from, x_to), (y_from, y_to) = np.asarray(xs), np.asarray(ys)

    # Logistic weights in (0, 1), as a column so they broadcast against the
    # N segment endpoints.
    grid = np.linspace(-smooth, smooth, num=n)[:, None]
    # tanh form of exp(t) / (exp(t) + 1): algebraically identical, but it
    # cannot overflow to inf / inf = nan when ``smooth`` is very large.
    weights = 0.5 * (1.0 + np.tanh(0.5 * grid))

    # x advances linearly from start to stop. Sampling the fraction directly
    # on [0, 1] avoids the 0 / 0 that rescaling ``grid`` hits for smooth == 0.
    fraction = np.linspace(0.0, 1.0, num=n)[:, None]

    return (
        fraction * (x_to - x_from) + x_from,
        weights * (y_to - y_from) + y_from,
    )
def sigmoid_pairwise(xs, ys, smooth=8, n=100):
    """Interpolate a sigmoid curve between every consecutive pair of points.

    For example::

        xs = [0, 1, 2, 3]
        ys = [2, 5, 3, 7]

    will interpolate between:

    - [0, 1], [2, 5]
    - [1, 2], [5, 3]
    - [2, 3], [3, 7]

    Parameters
    ----------
    xs, ys : array_like
        Both inputs should be 1-d arrays and have identical length.
    smooth : int
        see sigmoid func
    n : int
        see sigmoid func

    Returns
    -------
    tuple (xs, ys)
        The smoothed & interpolated x/y values as 1-d arrays of length
        ``(len(xs) - 1) * n``, laid out segment after segment. With fewer
        than two input points there is nothing to interpolate and copies of
        the inputs are returned unchanged.

    Raises
    ------
    ValueError
        If the inputs are not 1-d or do not have identical length.
    """
    xs = np.asarray(xs)
    ys = np.asarray(ys)
    if xs.ndim != 1 or xs.shape != ys.shape:
        raise ValueError(
            'xs and ys must be 1-d and of identical length, got shapes '
            f'{xs.shape} and {ys.shape}'
        )

    # A window of 2 needs at least two points; hand the input back so callers
    # can still draw whatever they have.
    if xs.size < 2:
        return xs.copy(), ys.copy()

    # Row i of each view is the pair (value[i], value[i + 1]).
    x_pairs = np.lib.stride_tricks.sliding_window_view(xs, 2)
    y_pairs = np.lib.stride_tricks.sliding_window_view(ys, 2)

    # Transposed, row 0 holds every start and row 1 every stop, which is the
    # (2, N) layout that sigmoid unpacks.
    interp_x, interp_y = sigmoid(x_pairs.T, y_pairs.T, smooth=smooth, n=n)

    # sigmoid returns one column per segment; transpose so that flattening
    # walks segment by segment. ravel() yields real arrays rather than
    # flatiter objects.
    return interp_x.T.ravel(), interp_y.T.ravel()
# Demo data: three companies, each observed in 2019, 2020 and 2021.
df = pd.DataFrame(
    {
        'year': list(range(2019, 2022)) * 3,
        'company': np.repeat(['A', 'B', 'C'], 3),
        'revenue': [100, 200, 300, 150, 250, 100, 200, 300, 400],
    }
)

# Global look: larger font, no right/top spines.
plt.rc('font', size=14)
plt.rc('axes.spines', right=False, top=False)

fig, ax = plt.subplots(figsize=(9, 6))

for company, group in df.groupby('company'):
    ordered = group.sort_values('year')
    years = ordered['year']
    revenue = ordered['revenue']

    # Smooth S-curve through the yearly observations.
    curve_x, curve_y = sigmoid_pairwise(years, revenue)
    (curve,) = ax.plot(curve_x, curve_y, lw=3)

    # Markers and label reuse the colour matplotlib picked for the curve.
    colour = curve.get_color()
    ax.scatter(years, revenue, s=100, color=colour)

    # Label is placed at the right edge of the axes (x given via
    # ax.transAxes), level with the last revenue value (y given via
    # ax.transData), shifted 5 points to the right.
    label = ax.annotate(
        f'Company {company.title()}',
        xy=(1, revenue.iloc[-1]),
        xycoords=(ax.transAxes, ax.transData),
        xytext=(5, 0),
        textcoords='offset points',
        color=colour,
        va='center',
    )

# Axes cosmetics: one major x tick per unit (i.e. per year).
ax.xaxis.set_major_locator(MultipleLocator(1))
ax.set_ylabel('Revenue')
ax.margins(x=.02)
ax.set_title('Revenue Bump Chart', size='x-large')

fig.tight_layout()
fig.savefig('bumpchart.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment