Skip to content

Instantly share code, notes, and snippets.

@mattpitkin
Last active February 20, 2024 16:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattpitkin/b9f5b7068e6b400c9fd49ba09a42fccb to your computer and use it in GitHub Desktop.
Violin plot with gradient colour fill

In answer to a StackOverflow question, here is a method (based heavily on an earlier answer) for creating a violin plot whose fill is a colour gradient that follows the estimated density of the samples (i.e., the width of the violin):

from matplotlib import pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch
import matplotlib as mpl

import numpy as np

from scipy.stats import gaussian_kde


# create KDE
def kde(values, npoints=100, bw_method=None):
    """Evaluate a Gaussian kernel density estimate of *values* on a grid.

    The grid consists of *npoints* evenly spaced points running from the
    smallest to the largest sample value (both ends included).

    Parameters
    ----------
    values : array_like
        1-D sample values (a list or a NumPy array).
    npoints : int, optional
        Number of grid points at which the KDE is evaluated (default 100).
    bw_method : str, scalar or callable, optional
        Bandwidth selector passed unchanged to ``scipy.stats.gaussian_kde``.
        The default ``None`` keeps SciPy's default (Scott's rule).

    Returns
    -------
    numpy.ndarray
        The KDE evaluated at the *npoints* grid points, ordered from the
        minimum to the maximum sample value.

    Raises
    ------
    ValueError
        If *values* is empty (there is no minimum/maximum to span).
    """
    values = np.asarray(values)
    vals = np.linspace(values.min(), values.max(), npoints, endpoint=True)

    # A constant (or single-element) sample has zero variance, which makes
    # gaussian_kde's covariance matrix singular and raises. Treat it as a
    # spike at that value instead: every grid point coincides with it.
    if np.all(values == values.flat[0]):
        return np.ones_like(vals)

    # generate and evaluate the KDE
    return gaussian_kde(values, bw_method=bw_method)(vals)


# example data: three normal samples centred on 0, 1 and 2. A seeded
# generator makes the figure reproducible from run to run.
rng = np.random.default_rng(42)
data = [rng.normal(loc=i, scale=1, size=(100,)) for i in range(3)]

fig, ax = plt.subplots()
violins = ax.violinplot(data, showextrema=False)

# get colour map
cmap = mpl.colormaps["Blues"]

# number of points at which to evaluate each Gaussian KDE
npoints = 100

# remember the axes ranges produced by violinplot so they can be restored
# after the gradient images have been added
ymin, ymax = ax.get_ylim()
xmin, xmax = ax.get_xlim()

# loop over each violin together with the sample it was drawn from
for violin, samples in zip(violins["bodies"], data):
    # get extent of violin (data coordinates)
    extent = violin.get_datalim(ax.transData)

    # add an invisible patch with the violin's outline to act as a clip path;
    # reusing the violin's own path keeps its vertices and path codes intact
    path = violin.get_paths()[0]
    patch = PathPatch(path, facecolor="none", edgecolor="none")
    violin.set_visible(False)  # make original violin invisible
    ax.add_patch(patch)

    # kde() samples the density at npoints evenly spaced y values from the
    # minimum to the maximum sample (both ends included), whereas imshow
    # spreads its rows edge-to-edge over the given extent. Pad the extent
    # by half a sample spacing at each end so every row is centred on the y
    # value it was evaluated at; the clip path trims the overhang.
    half_step = 0.5 * (extent.ymax - extent.ymin) / max(npoints - 1, 1)

    # single-column image whose rows (bottom to top) are the KDE values
    cvalues = np.atleast_2d(kde(samples, npoints)).T
    ax.imshow(
        cvalues,
        origin="lower",
        extent=[
            extent.xmin,
            extent.xmax,
            extent.ymin - half_step,
            extent.ymax + half_step,
        ],
        cmap=cmap,
        aspect="auto",
        clip_path=patch,
    )

# reset axes ranges
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment