Skip to content

Instantly share code, notes, and snippets.

@pdavidsonFIA
Created January 24, 2021 11:40
Show Gist options
  • Save pdavidsonFIA/770ddb4b06dc61bdd3e9c67d484030a9 to your computer and use it in GitHub Desktop.
Save pdavidsonFIA/770ddb4b06dc61bdd3e9c67d484030a9 to your computer and use it in GitHub Desktop.
Waterfall chart from series data
import matplotlib, matplotlib.pyplot as plt, matplotlib.ticker as tick
import seaborn as sns
def get_waterfall_colours(df, stocks=None):
    """Return one bar colour per entry of ``df`` for a waterfall chart.

    Colours are taken from seaborn's "tab10" palette:

    * stock (balance) entries: blue when the value is >= 0, brown otherwise;
    * all other entries: green when the value is >= 0 (uptick), red
      otherwise (downtick).

    Args:
        df: Series data; iterated with ``df.items()`` as (label, value) pairs.
        stocks(optional): Collection of labels representing balances.
            ``None`` (the default) means no entry is treated as a balance.

    Returns:
        List of colours, in the same order as the entries of ``df``.
    """
    # The default of None must behave as "no stocks"; testing membership
    # against None itself would raise TypeError.
    if stocks is None:
        stocks = ()

    tab10 = sns.color_palette("tab10")
    # Positions of red, green, blue and brown within the tab10 palette.
    red, green, blue, brown = (tab10[i] for i in (3, 2, 0, 5))

    colours = []
    for label, value in df.items():
        if label in stocks:
            colours.append(blue if value >= 0 else brown)
        else:
            colours.append(green if value >= 0 else red)
    return colours
def draw_waterfall(df, ax=None, stocks=None, title=None, filename=None, show=True, scaling = 1):
    """Draw a waterfall chart from series data.

    The first entry and every entry named in ``stocks`` is drawn as a bar
    starting at zero. Every other entry is drawn as a floating bar that
    starts at the level reached by the previous bar.

    Args:
        df: Series data.
        ax(optional): Axes. The current axes of a new figure is used by default.
        stocks(optional): List of names representing balances.
        title(optional): Graph title string.
        filename(optional): When not None, the figure is saved to this path.
        show(optional): When true, select the 'TkAgg' backend and display the
            figure; otherwise select the 'Agg' backend and do not display it.
        scaling(optional): Divisor applied to the values shown in the y tick
            labels and in the bar labels.

    Returns:
        The Axes the chart was drawn on.

    Raises:
        ValueError: If ``df`` is empty.

    Example:
        import pandas as pd
        df = pd.Series({ "Premiums": 20,
                         "Claims": -10,
                         "Expenses": -5})
        df['Total']=df.sum()
        draw_waterfall(df, title='demo', stocks=['Total'], show=True)
    """
    if len(df) == 0:
        raise ValueError("draw_waterfall needs at least one value to plot")

    # NOTE(review): the backend is switched on every call, as in the original
    # implementation; confirm this is still wanted when the caller supplies
    # its own Axes or runs without Tk available.
    if show:
        matplotlib.use('TkAgg')
    else:
        matplotlib.use('Agg')

    if ax is None:
        plt.figure()
        ax = plt.gca()
    if stocks is None:
        stocks = []

    # Walk the series once, tracking the running level so that each flow bar
    # starts where the previous bar ended. Work by position rather than by
    # label so duplicate index labels cannot break the lookups.
    bottoms, tops, lows = [], [], []
    level = 0
    for pos, (label, value) in enumerate(df.items()):
        base = 0 if pos == 0 or label in stocks else level
        level = base + value
        bottoms.append(base)
        tops.append(base + max(0, value))   # upper edge of the bar
        lows.append(base + min(0, value))   # lower edge of the bar

    palette = get_waterfall_colours(df, stocks)
    xlabel = list(df.index)

    # Size the y axis from the actual bar edges (with 10% headroom) so that
    # neither stacked increases nor stock totals are clipped or double-counted.
    # The first bar starts at zero, so min(lows) <= 0 <= max(tops).
    ax.set(ylim=(1.1 * min(lows), 1.1 * max(tops)))
    ax.set_ylabel('Amount €m', loc='center')

    def format_tick(x, pos):
        return "{:,.1f}".format(x/scaling)

    ax.yaxis.set_major_formatter(tick.FuncFormatter(format_tick))
    ax = sns.barplot(data=[[i] for i in df], bottom=bottoms,
                     palette=palette,
                     ax=ax)
    ax.set_xticklabels(labels=xlabel, rotation='vertical')

    # Write each value at the upper edge of its bar.
    for pos, (value, top) in enumerate(zip(df, tops)):
        ax.text(float(pos), top, "{:,.2f}".format(value/scaling), color='black', ha="center")

    if title:
        ax.set_title(title)

    fig = ax.get_figure()
    fig.tight_layout()
    if filename is not None:
        # Save the figure that owns ``ax``; the current pyplot figure may be
        # a different one when the caller passed in their own Axes.
        fig.savefig(filename)
    if show:
        plt.show()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment