Skip to content

Instantly share code, notes, and snippets.

@ecolss
Last active September 17, 2023 15:55
Show Gist options
  • Save ecolss/c46d94eddf5390c514e2dc509645d2c4 to your computer and use it in GitHub Desktop.
Save ecolss/c46d94eddf5390c514e2dc509645d2c4 to your computer and use it in GitHub Desktop.
A ridge-plot (joy plot) function built on matplotlib, using scikit-learn's `KernelDensity` to estimate each ridge's density.
import glob
import numpy as np
from matplotlib.gridspec import GridSpec
import numpy as np
import pylab as pl
from sklearn.neighbors import KernelDensity
def ridge_plot(groups, conf, *, show=True):
    """Draw a ridge plot: one kernel-density curve per group, stacked vertically.

    Each group's 1-D samples are fitted with a Gaussian ``KernelDensity`` and
    the resulting density is evaluated on one shared x grid. Every density gets
    its own row in a single-column ``GridSpec``. A negative ``conf["hspace"]``
    makes the rows overlap, and each row's background is made transparent so
    the overlapping ridges stay visible.

    Args:
        groups: Sequence of dicts, one per ridge, drawn top to bottom.
            Required keys: ``"name"`` (text drawn at the left end of the
            ridge's baseline) and ``"data"`` (1-D array-like of samples).
            Optional key ``"color"`` overrides ``conf["color"]`` for that ridge.
        conf: Plot configuration dict with keys:
            ``"figsize"`` - passed to ``pl.figure``;
            ``"x_range"`` - ``(start, stop, num)`` passed to ``np.linspace``;
            ``"kde_bw"`` - KDE bandwidth;
            ``"color"`` - default color dict with ``"fill"`` (bool) and, when
            filling, ``"fill_color"`` and ``"fill_alpha"``;
            ``"x_label"`` / ``"x_label_kwargs"`` - label of the bottom row;
            ``"y_label_kwargs"`` - kwargs for each ridge's name text;
            ``"hspace"`` - vertical spacing between rows (negative overlaps);
            ``"y_pad"`` (optional, default 0.02) - headroom added above each
            ridge's peak density.
        show: If True (the default, matching the original behavior), call
            ``pl.show()`` before returning.

    Returns:
        The created matplotlib ``Figure``, so callers can save or tweak it.

    Raises:
        TypeError: if a group is not a dict.
        ValueError: if ``groups`` is empty, a group lacks ``"name"`` or
            ``"data"``, or a group's data is not a non-empty 1-D array.
    """
    # Materialize once so generators work and len() is available.
    groups = list(groups)
    if not groups:
        raise ValueError("ridge_plot() needs at least one group")
    # Validate every group up front, before any figure is created.
    for idx, group in enumerate(groups):
        if not isinstance(group, dict):
            raise TypeError(f"group #{idx} must be a dict, got {type(group).__name__}")
        missing = [key for key in ("name", "data") if key not in group]
        if missing:
            raise ValueError(f"group #{idx} is missing required key(s): {missing}")

    n_rows = len(groups)
    gs = GridSpec(n_rows, 1)
    fig = pl.figure(figsize=conf["figsize"])

    # x_range = (from, to, n_tick); every ridge is evaluated on the same grid.
    x_ticks = np.linspace(*conf["x_range"])
    # sklearn estimators expect a 2-D (n_samples, n_features) array.
    x_grid = x_ticks.reshape(-1, 1)
    y_pad = conf.get("y_pad", 0.02)
    kde = KernelDensity(bandwidth=conf["kde_bw"], kernel="gaussian")
    global_color_conf = conf["color"]

    for i, group in enumerate(groups):
        name = group["name"]
        data = np.asarray(group["data"])
        if data.ndim != 1 or data.size == 0:
            raise ValueError(
                f"group {name!r}: 'data' must be a non-empty 1-D array, "
                f"got shape {data.shape}"
            )
        color_conf = group.get("color", global_color_conf)

        # Refit the (reused) estimator on this group's samples, then evaluate
        # its log-density on the shared grid. score_samples returns log p(x);
        # it does not draw samples.
        kde.fit(data.reshape(-1, 1))
        # A density, not a probability, so values can be > 1.
        pdensity = np.exp(kde.score_samples(x_grid))

        ax = fig.add_subplot(gs[i, 0])
        ax.plot(x_ticks, pdensity, lw=1)
        if color_conf["fill"]:
            ax.fill_between(
                x_ticks, pdensity,
                alpha=color_conf["fill_alpha"], color=color_conf["fill_color"],
            )

        # X/y limits.
        ax.set_xlim(x_ticks[0], x_ticks[-1])
        ax.set_ylim(0, pdensity.max() + y_pad)
        # Transparent background so overlapping ridges below stay visible.
        ax.patch.set_alpha(0)
        # Hide y ticks and y tick labels.
        ax.tick_params(left=False, labelleft=False)
        if i == n_rows - 1:
            ax.set_xlabel(conf["x_label"], **conf["x_label_kwargs"])
        else:
            # Only the bottom row keeps its x tick labels.
            ax.tick_params(labelbottom=False)
        for spine in ("top", "right", "left", "bottom"):
            ax.spines[spine].set_visible(False)
        ax.text(x_ticks[0], 0, name, **conf["y_label_kwargs"])

    gs.update(hspace=conf["hspace"])
    fig.tight_layout()
    if show:
        pl.show()
    return fig
@ecolss
Copy link
Author

ecolss commented Sep 10, 2023

An example of how to use the `ridge_plot` function (note: `T`, `dim_idx` and `dim` are not defined in the snippet and must come from your own context):

# Build one ridge per (time step, base dimension, offset) triple; each offset
# gets its own fill color so neighbouring dimensions are easy to tell apart.
# NOTE(review): `T`, `dim_idx` and `dim` are not defined in this snippet and
# must already exist in the caller's session. From the indexing below, T[step]
# supports 2-D slicing and `dim` picks row (0) vs column slicing -- confirm
# against the data actually used.
offset_palette = [(0, "lightgreen"), (1, "violet"), (2, "cyan"), (10, "gold")]
n_steps = min(len(T), 100)

groups = []
for step in range(0, n_steps, 5):
    for base in dim_idx:
        for offset, fill_color in offset_palette:
            k = base + offset
            if dim == 0:
                samples = T[step][k]
            else:
                samples = T[step][:, k]
            color_spec = {
                "fill": True,
                "fill_color": fill_color,
                "fill_alpha": 1.0,
            }
            groups.append(
                {"name": f"step-{step}-{k}", "data": samples, "color": color_spec}
            )


# Fallback color for any group without its own "color" entry.
default_color = {
    "fill": True,
    "fill_color": "orange",
    "fill_alpha": 1.0,
}

conf = {
    # KDE bandwidth: important to get proper kde results.
    "kde_bw": 0.01,
    # Arguments for np.linspace(start, stop, num) -> the shared x grid.
    "x_range": (-1.1, 1.1, 500),
    "x_label": "weights",
    "x_label_kwargs": {"fontsize": 16, "fontweight": "bold"},
    "y_label_kwargs": {"fontsize": 13, "fontweight": "bold", "ha": "right"},
    "figsize": (16, 30),
    # Negative spacing lets neighbouring rows overlap.
    "hspace": -0.7,
    "color": default_color,
}

ridge_plot(groups, conf)

[Image: example ridge plot produced by the code above]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment