Skip to content

Instantly share code, notes, and snippets.

@tulerpetontidae
Last active March 12, 2023 20:48
Show Gist options
  • Save tulerpetontidae/09a1a8bb61706c7514f33dc50a6917fd to your computer and use it in GitHub Desktop.
Save tulerpetontidae/09a1a8bb61706c7514f33dc50a6917fd to your computer and use it in GitHub Desktop.
Density scatter plot with marginals
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import matplotlib.ticker as plticker
rng = np.random.default_rng(42)  # fixed seed so the demo data is reproducible


def generate_dist(mu=(0, 0), sigma1=1, sigma2=2, rho=0.5, n_points=250,
                  generator=None):
    """
    Generate a sample from a bivariate normal distribution.

    :param mu: (mean_x, mean_y) of the distribution; a tuple default is used
        so the default cannot be mutated between calls
    :param sigma1: standard deviation along x
    :param sigma2: standard deviation along y
    :param rho: correlation coefficient between x and y
    :param n_points: number of points to draw
    :param generator: optional ``np.random.Generator`` to draw from; when
        None the module-level seeded ``rng`` is used
    :return: x, y numpy arrays of coordinates, each of length ``n_points``
    """
    if generator is None:
        generator = rng
    cov_xy = rho * sigma1 * sigma2
    x, y = generator.multivariate_normal(
        mu,
        [[sigma1 ** 2, cov_xy],
         [cov_xy, sigma2 ** 2]],
        n_points,
    ).T
    return x, y
def kde(x, y, xmin, xmax, ymin, ymax, lim=1e-3, n_grid=500):
    """
    Evaluate a Gaussian KDE of the points (x, y) on a regular grid.

    :param x, y: 1-D sequences of point coordinates
    :param xmin, xmax, ymin, ymax: extent of the evaluation grid; both
        endpoints of each range are included
    :param lim: minimal kde value to present, anything lower is set to 0
    :param n_grid: number of grid nodes along each axis (default 500)
    :return: x grid, y grid, kde evaluated at each grid node; all three
        arrays have shape (n_grid, n_grid)
    """
    # A complex step tells mgrid "this many points, endpoints included".
    steps = complex(0, n_grid)
    xx, yy = np.mgrid[xmin:xmax:steps, ymin:ymax:steps]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    kernel = st.gaussian_kde(np.vstack([x, y]))
    density = kernel(positions).reshape(xx.shape)
    density[density < lim] = 0
    return xx, yy, density
def boxplot_annot(x1, x2, y1, y2, text, ax, d=0.1, vert=True):
    """
    Draw a bracket between two boxes and put a text label next to it.

    The bracket legs start at ``y1 + d`` and ``y2 + d`` and are joined at
    ``max(y1, y2) + 2 * d``.

    :param x1, x2: positions of the two boxes along the category axis
    :param y1, y2: data-axis values the bracket legs start from
    :param text: label placed next to the bracket
    :param ax: axes to draw on
    :param d: spacing unit, in data coordinates
    :param vert: True for vertical boxplots (bracket above the boxes),
        False for horizontal ones (bracket to the right, label left-aligned)
    """
    peak = np.max([y1, y2])
    bracket_level = peak + d * 2
    legs = [y1 + d, bracket_level, bracket_level, y2 + d]
    across = [x1, x1, x2, x2]

    if not vert:
        ax.plot(legs, across, c='k', lw=1)
        ax.text(peak + d * 3, (x1 + x2) / 2, text,
                horizontalalignment='left',
                verticalalignment='center')
        return

    ax.plot(across, legs, c='k', lw=1)
    ax.text((x1 + x2) / 2, peak + d * 4, text,
            horizontalalignment='center',
            verticalalignment='center')
def boxplot(dat, xmin, xmax, annot=None, colours=None, ax=None, vert=True):
    """
    Display boxplots of marginal data, one box per dataset.

    The boxes are spread evenly over ``[xmin, xmax]`` along the category axis,
    each centred in its own slot of width ``(xmax - xmin) / len(dat)`` and
    drawn 0.6 slot wide.

    :param dat: sequence of 1-D datasets, one per box
    :param xmin, xmax: extent of the category axis to spread the boxes over
    :param annot: optional ``(i, j, text)``; draws a bracket labelled ``text``
        between boxes ``i`` and ``j`` via ``boxplot_annot``
    :param colours: optional box face colours, one per dataset; when None the
        next colours of the axes' colour cycle are used
    :param ax: axes to draw on; defaults to the current axes
    :param vert: True for vertical boxes, False for horizontal ones
    """
    if ax is None:
        ax = plt.gca()
    n = len(dat)
    if n == 0:
        return  # nothing to draw; also avoids dividing by zero below
    slot = (xmax - xmin) / n
    # Offset by xmin so the boxes are centred inside [xmin, xmax] rather
    # than inside [0, xmax - xmin].
    pos = xmin + (np.arange(n) + 0.5) * slot
    bplots = [
        ax.boxplot(d, positions=[p], widths=slot * 0.6, vert=vert,
                   patch_artist=True, manage_ticks=False)
        for d, p in zip(dat, pos)
    ]
    if colours is None:
        # Pull the next colour(s) from the axes' cycle by drawing an empty
        # scatter (the original author knew no way to read the cycle without
        # advancing it), then remove it so no empty artist stays on the axes.
        colours = []
        for _ in range(n):
            sc = ax.scatter([], [])
            colours.append(list(sc.get_facecolor()[0][:-1]) + [1])
            sc.remove()
    for bp, colour in zip(bplots, colours):
        bp['boxes'][0].set_facecolor(colour)
        bp['medians'][0].set_color('k')
        bp['medians'][0].set_linewidth(3)
    if annot is not None:
        xi1, xi2, val = annot
        # The bracket is offset along the data axis: y for vertical boxes,
        # x for horizontal ones, so scale the spacing by that axis' limit.
        data_lim = ax.get_ylim() if vert else ax.get_xlim()
        boxplot_annot(pos[xi1], pos[xi2], np.max(dat[xi1]), np.max(dat[xi2]),
                      val, ax=ax, d=0.05 * data_lim[1], vert=vert)
def mainplot(dat, xmin, xmax, ymin, ymax, colours=None, ax=None):
    """
    Display 2-D point clouds, each overlaid with its KDE contours.

    Every dataset is drawn as a semi-transparent scatter (``alpha=0.5``) and
    then contoured (``levels=5``) from ``kde`` in the same colour with the
    alpha forced back to 1.

    :param dat: sequence of (x, y) coordinate pairs, one per dataset
    :param xmin, xmax, ymin, ymax: extent of the KDE evaluation grid
    :param colours: optional colours, one per dataset; None lets the axes'
        colour cycle choose
    :param ax: axes to draw on; defaults to the current axes
    """
    ax = plt.gca() if ax is None else ax
    for idx, coords in enumerate(dat):
        colour = None if colours is None else colours[idx]
        points = ax.scatter(*coords, s=15, alpha=0.5, c=colour)
        # Reuse the scatter's RGB, but make the contour lines fully opaque.
        rgb = points.get_facecolors()[0][:-1]
        contour_colour = [list(rgb) + [1]]
        grid_x, grid_y, density = kde(*coords, xmin, xmax, ymin, ymax)
        ax.contour(grid_x, grid_y, density, 5, colors=contour_colour)
def plot_legend(dat, labels, colours=None, ax=None):
    """
    Draw a frameless, centred legend with one '-o' entry per dataset.

    The legend handles are lines with no data, so apart from the legend
    itself nothing is drawn.

    :param dat: datasets; only ``len(dat)`` is used, to count the entries
    :param labels: legend text, indexed per dataset
    :param colours: optional colours, one per dataset
    :param ax: axes to draw on; defaults to the current axes
    """
    ax = plt.gca() if ax is None else ax
    handles = []
    for idx in range(len(dat)):
        colour = None if colours is None else colours[idx]
        line, = ax.plot([], [], '-o', c=colour, label=labels[idx])
        handles.append(line)
    ax.legend(handles=handles, loc='center', frameon=False)
def set_style(axs, loc_base=10):
    """
    Apply the standard style to the 2x2 marginal-plot grid.

    - Hides the top and right spines of ``axs[0, 0]``, ``axs[0, 1]`` and
      ``axs[1, 1]``.
    - Removes the bottom/top ticks and bottom tick labels of ``axs[0, 0]``.
    - Removes the left/right ticks and left tick labels of ``axs[1, 1]``.
    - Places major ticks of ``axs[0, 1]`` at multiples of ``loc_base``.
    - Turns ``axs[1, 0]`` off.

    :param axs: 2x2 array of axes, e.g. as returned by ``plt.subplots(2, 2)``
    :param loc_base: spacing of the major ticks on ``axs[0, 1]``
    """
    for ax in (axs[0, 0], axs[0, 1], axs[1, 1]):
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
    axs[0, 0].tick_params(bottom=False, top=False, labelbottom=False)
    axs[1, 1].tick_params(left=False, right=False, labelleft=False)
    # A Locator keeps a reference to the Axis it is attached to, so x and y
    # need separate instances; a shared one would compute the x ticks from
    # the y axis' view limits (the Axis it was attached to last).
    axs[0, 1].xaxis.set_major_locator(plticker.MultipleLocator(base=loc_base))
    axs[0, 1].yaxis.set_major_locator(plticker.MultipleLocator(base=loc_base))
    axs[1, 0].axis('off')
# --- generate data (call order matters: every sample is drawn from the shared rng)
x1, y1 = generate_dist(mu=[16, 28], sigma1=2, sigma2=2.5, rho=0.7, n_points=30)
x2, y2 = generate_dist(mu=[24, 40], sigma1=3, sigma2=5, rho=0.5, n_points=200)
x3, y3 = generate_dist(mu=[14, 20], sigma1=3, sigma2=3, rho=0.5, n_points=50)
# The first plotted distribution is the union of the first two samples.
x1, y1 = np.concatenate([x1, x2]), np.concatenate([y1, y2])

# --- plot range, colours and datasets
xmin, xmax = ymin, ymax = 1, 65
colours = ['cornflowerblue', 'darkorange']  # set to None to use the default colour cycle
data_list = [(x1, y1), (x3, y3)]
p_label = r'$P < 2\times10^{-16}$'  # fixed annotation text, not computed from the data

# --- figure: main plot top-right, marginal boxplots left and bottom, legend bottom-left
fig, axs = plt.subplots(
    2, 2, figsize=(6, 4), dpi=150, sharex=True, sharey=True,
    gridspec_kw={'width_ratios': [1, 5], 'height_ratios': [5, 1]},
)
set_style(axs)
mainplot(data_list, xmin, xmax, ymin, ymax, ax=axs[0, 1], colours=colours)
boxplot([y1, y3], ymin, ymax, ax=axs[0, 0], annot=[0, 1, p_label],
        colours=colours, vert=True)
boxplot([x1, x3], xmin, xmax, ax=axs[1, 1], annot=[0, 1, p_label],
        colours=colours, vert=False)
plot_legend(data_list, labels=['Dist 1', 'Dist 2'], ax=axs[1, 0], colours=colours)

# --- labels and limits (all axes share x and y, so limits set on one apply to all)
# Optional title: fig.suptitle('Plot implemented in pure matplotlib')
axs[1, 1].set_xlabel('Values on x-axis')
axs[0, 0].set_ylabel('Values on y-axis')
axs[1, 1].set_xlim(xmin, xmax)
axs[1, 1].set_ylim(ymin, ymax)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment