Skip to content

Instantly share code, notes, and snippets.

@bwengals
Created February 14, 2017 02:42
Show Gist options
  • Save bwengals/9fcf4615b1f47a405aeaaa6e725b9dcb to your computer and use it in GitHub Desktop.
Save bwengals/9fcf4615b1f47a405aeaaa6e725b9dcb to your computer and use it in GitHub Desktop.
2D credible intervals in matplotlib from MCMC samples
def hpdplot2d(x_samples, y_samples, bins, perc=0.95, extent=None):
    """Estimate the density contour level enclosing a credible mass of 2-D samples.

    A Gaussian KDE is fitted to the paired samples and evaluated on a regular
    ``bins`` x ``bins`` grid.  The returned level is the grid-density threshold
    whose super-level set ``z > level`` holds the fraction of the gridded
    density mass closest to ``perc``.  This is a discrete approximation of the
    highest-density credible region.

    Parameters
    ----------
    x_samples, y_samples : array_like, shape (n,)
        Paired draws (e.g. two parameters from an MCMC chain).
    bins : int
        Number of grid points along each axis.
    perc : float, optional
        Target credible mass, in ``(0, 1]``.  Default 0.95.
    extent : sequence of 4 floats, optional
        ``(xmin, xmax, ymin, ymax)`` of the evaluation grid.  Defaults to the
        range of the samples.

    Returns
    -------
    z : ndarray, shape (bins, bins)
        KDE evaluated on the grid, scaled so its maximum is 1.  Rows follow
        ``y`` and columns follow ``x`` (``np.meshgrid`` "xy" indexing).
    result : scipy.optimize.OptimizeResult
        ``result.x`` is the contour level in the units of ``z``.
        ``result.fun`` is the squared error between the mass fraction enclosed
        by ``z > result.x`` and ``perc``.

    Raises
    ------
    ValueError
        If ``perc`` is not in ``(0, 1]``.
    """
    # Use the public scipy.stats namespace; scipy.stats.kde is deprecated.
    from scipy.optimize import OptimizeResult
    from scipy.stats import gaussian_kde

    if not 0 < perc <= 1:
        raise ValueError(f"perc must be in (0, 1], got {perc!r}")

    if extent is None:
        xmin, xmax = np.min(x_samples), np.max(x_samples)
        ymin, ymax = np.min(y_samples), np.max(y_samples)
    else:
        xmin, xmax, ymin, ymax = extent
    x_flat = np.linspace(xmin, xmax, bins)
    y_flat = np.linspace(ymin, ymax, bins)
    x, y = np.meshgrid(x_flat, y_flat)

    # gaussian_kde takes (n_dims, n_points) arrays for both data and query points.
    kde = gaussian_kde(np.vstack((x_samples, y_samples)))
    z = kde(np.vstack((x.ravel(), y.ravel()))).reshape(bins, bins)
    z = z / np.max(z)
    total = np.sum(z)

    # The enclosed fraction is a step function of the threshold and only
    # changes at grid values, so a 1-D optimizer can stall on a plateau.
    # Instead, score every distinct grid value as a candidate threshold.
    levels, counts = np.unique(z, return_counts=True)  # ascending, distinct
    weights = levels * counts
    # mass_above[j] = sum of z strictly greater than levels[j] (reverse cumsum,
    # shifted by one so levels[j] itself is excluded).
    mass_above = np.append(np.cumsum(weights[::-1])[::-1][1:], 0.0)
    errors = np.square(mass_above / total - perc)
    level = float(levels[np.argmin(errors)])

    # Keep the OptimizeResult return type so callers reading .x / .fun still work.
    return z, OptimizeResult(
        x=level,
        fun=float(np.square(np.sum(z[z > level]) / total - perc)),
        success=True,
        status=0,
        message="Exact search over grid density values.",
        nfev=int(levels.size),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment