Created
December 29, 2020 03:24
-
-
Save alexklapheke/249016549d6f24c085f76e82a0b8247d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.ticker import StrMethodFormatter | |
# Neutral dark gray used as the error-bar color in ``barh``.
darkgray = "#4c4c4c"
def barh(
    series,
    label_fmt,
    tick_fmt=None,
    color="C0",
    xlim=None,
    xticks=None,
    ylabels=None,
    error=None,
    title=None,
    stacked=False,
    ax=None,
):
    """Horizontal bar plot with optional error bars and value labels.

    Rows are drawn top-to-bottom in the order they appear in ``series``.

    Args:
        series: Values to plot. A Series gives one bar per entry; with
            ``stacked=True`` a DataFrame is expected and each row becomes one
            stacked bar whose segments are the columns.
        label_fmt: ``str.format`` template with an ``{x}`` field, used for the
            text label drawn to the right of each bar.
        tick_fmt: ``{x}`` template for x-axis tick labels; defaults to
            ``label_fmt``.
        color: Bar color(s), passed through to matplotlib / pandas.
        xlim: Upper x-axis limit (the lower limit is fixed at 0). Error-driven
            label offsets are capped at ``xlim / 2.5``.
        xticks: Explicit x tick positions.
        ylabels: Replacement y tick labels, applied in axis order (bottom to
            top, i.e. the reverse of ``series`` order).
        error: Per-row standard errors, in the same order (and, for pandas
            objects, with the same index) as the rows of ``series``. Drawn as
            1.96 * error bars whose left arm never extends below 0.
        title: Left-aligned axes title.
        stacked: Draw ``series`` as stacked bars. Labels and error bars then
            refer to each row's total.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, series.shape[0] / 3))

    bar_height = 0.9
    # Reverse so the first row ends up at the top of the axes.
    series = series[::-1]
    if stacked:
        series.plot.barh(ax=ax, stacked=True, color=color, width=bar_height)
        # pandas centres row i at y == i; labels/error bars use row totals.
        values = series.sum(axis=1)
        y = np.arange(len(values))
        widths = list(values)
        bottoms = [pos - bar_height / 2 for pos in y]
    else:
        y = series.index.astype(str)
        bars = ax.barh(y, series, height=bar_height, color=color)
        values = series
        widths = [bar.get_width() for bar in bars]
        bottoms = [bar.get_y() for bar in bars]

    if error is not None:
        error = error[::-1]  # keep aligned with the reversed rows
        ax.errorbar(
            values,
            y,
            xerr=(  # 95% conf., but don't let the left arm go below 0
                np.minimum(error * 1.96, values),
                error * 1.96,
            ),
            fmt="none",
            color=darkgray,
            capsize=3,
        )

    # Format axes
    if title:
        ax.set_title(title, loc="left")
    if ylabels is not None:
        ax.set_yticklabels(ylabels)
    if not tick_fmt:
        tick_fmt = label_fmt
    if xlim:
        ax.set_xlim((0, xlim))
    if xticks is not None:
        ax.set_xticks(xticks)
    ax.xaxis.set_major_formatter(StrMethodFormatter(tick_fmt))

    # Draw labels. Text does not autoscale, so the limit can be read once.
    x_max = ax.get_xlim()[1]
    if error is None:
        offsets = [0] * len(widths)
    else:
        # Push labels past the error bar; fall back to the actual axis limit
        # when no explicit xlim was given (previously a TypeError on None).
        cap = xlim if xlim is not None else x_max
        offsets = np.minimum(cap / 2.5, error * 2)
    for offset, width, bottom in zip(offsets, widths, bottoms):
        if error is None and width == 0:
            x = 0
        else:
            x = width + x_max / 33 + offset
        ax.text(x=x, y=bottom + 0.3, s=label_fmt.format(x=width))
    return ax
def barh_stacked(
    df,
    colors,
    label_fmt,
    xticks=None,
    yticks=None,
    title=None,
    ylabels=None,
    labels=False,
    tick_fmt=None,
    ax=None,
):
    """Stacked horizontal bar plot with optional row-total labels.

    Rows of ``df`` are drawn top-to-bottom in their original order; each
    column is one stacked segment.

    Args:
        df: DataFrame to plot, one stacked bar per row.
        colors: Segment colors, passed through to ``DataFrame.plot.barh``.
        label_fmt: ``str.format`` template with an ``{x}`` field, used for the
            row-total labels.
        xticks: Explicit x tick positions.
        yticks: Explicit y tick positions. They are reversed before being
            applied, and ``ylabels`` are matched against them in that
            reversed order.
        title: Left-aligned axes title.
        ylabels: Replacement y tick labels.
        labels: When true, write each row's total (sum across columns) to the
            right of its bar.
        tick_fmt: ``{x}`` template for x-axis tick labels; defaults to
            ``label_fmt``.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, df.shape[0] / 3))

    bar_height = 0.9
    # Reverse so the first row ends up at the top; pandas centres row i at y == i.
    df = df[::-1]
    df.plot.barh(ax=ax, stacked=True, color=colors, width=bar_height)

    # Format axes
    if title:
        ax.set_title(title, loc="left")
    if yticks is not None:
        ax.set_yticks(yticks[::-1])
    if ylabels is not None:
        ax.set_yticklabels(ylabels)
    if xticks is not None:
        ax.set_xticks(xticks)
    if not tick_fmt:
        tick_fmt = label_fmt
    ax.xaxis.set_major_formatter(StrMethodFormatter(tick_fmt))
    ax.set_ylabel("")

    # Draw labels at bar *positions* of the plotted (reversed) frame. Using
    # index labels as coordinates broke for non-numeric indexes and, after
    # the double reversal, put each total on the mirrored bar.
    if labels:
        x_pad = ax.get_xlim()[1] / 33
        for position, total in enumerate(df.sum(axis=1)):
            ax.text(
                total + x_pad,
                # Same baseline as ``barh``: bar bottom + 0.3.
                position - bar_height / 2 + 0.3,
                label_fmt.format(x=total),
            )
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment