Skip to content

Instantly share code, notes, and snippets.

@bendichter
Last active July 15, 2022 18:34
Show Gist options
  • Save bendichter/908d4dbca45bf7cf77583b902c4d54c9 to your computer and use it in GitHub Desktop.
Save bendichter/908d4dbca45bf7cf77583b902c4d54c9 to your computer and use it in GitHub Desktop.
plot a grouped barplot
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional
def grouped_barplot(
    data,
    clabels: List[str],
    xlabels: List[str],
    gap: float = 0.3,
    show_legend: bool = True,
    show_bar_labels: bool = True,
    ax: Optional[plt.Axes] = None,
):
    """Draw a grouped (clustered) bar chart and return the axes.

    One bar series is drawn per entry of ``clabels``. At each x position the
    series' bars are placed side by side, and the whole cluster is centered
    on that position's tick.

    Parameters
    ----------
    data
        Iterable of bar-height sequences, one per series, in the same order
        as ``clabels``. Each inner sequence is passed to ``Axes.bar`` as the
        heights for the ``xlabels`` positions (presumably it has the same
        length as ``xlabels``).
    clabels
        Legend label for each series.
    xlabels
        Tick label for each x position.
    gap
        Fraction of each unit-wide x slot left empty between neighbouring
        clusters. Must satisfy ``0 <= gap < 1``.
    show_legend
        Whether to draw a legend of the series labels.
    show_bar_labels
        Whether to annotate each bar with its value via ``Axes.bar_label``.
    ax
        Axes to draw into. When omitted, a new figure and axes are created
        with ``plt.subplots()``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the chart was drawn into.

    Raises
    ------
    ValueError
        If ``clabels`` is empty, if ``data`` and ``clabels`` contain a
        different number of series, or if ``gap`` is outside ``[0, 1)``.
    """
    # Materialize once so the length can be checked (data may be a generator
    # or a 2-D array; iterating either yields one series per item).
    series = list(data)
    n_series = len(clabels)
    if n_series == 0:
        raise ValueError("clabels must contain at least one series label")
    if len(series) != n_series:
        # zip() would otherwise silently drop the unmatched series/labels.
        raise ValueError(
            f"data has {len(series)} series but clabels has {n_series} labels"
        )
    if not 0 <= gap < 1:
        # gap >= 1 would give zero or negative bar widths.
        raise ValueError(f"gap must be in [0, 1), got {gap}")

    if ax is None:
        _, ax = plt.subplots()

    x = np.arange(len(xlabels))  # one unit-spaced slot per x label
    cluster_width = 1 - gap  # total width occupied by one cluster of bars
    width = cluster_width / n_series  # width of a single bar

    for i, (cdata, clabel) in enumerate(zip(series, clabels)):
        # Axes.bar centers each bar on its x coordinate (align="center").
        # Bar i therefore sits at the left edge of the cluster plus
        # (i + 0.5) bar widths. That keeps the cluster spanning
        # [x - cluster_width/2, x + cluster_width/2], centered on the tick.
        offset = -cluster_width / 2 + (i + 0.5) * width
        rects = ax.bar(x + offset, cdata, width, label=clabel)
        if show_bar_labels:
            ax.bar_label(rects, padding=3)

    # Place one tick per x slot, labelled with the corresponding xlabel.
    ax.set_xticks(x, xlabels)
    if show_legend:
        ax.legend()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment