-
-
Save tboggs/8778945 to your computer and use it in GitHub Desktop.
'''Functions for drawing contours of Dirichlet distributions. | |
MIT License | |
Copyright (c) 2014 Thomas Boggs | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
''' | |
from __future__ import division, print_function | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.tri as tri | |
# Corners of the unit-side equilateral triangle used to draw the 2-simplex,
# one (x, y) row per corner.
_corners = np.array([
    [0, 0],
    [1, 0],
    [0.5, 0.75 ** 0.5],
])
# Area of that triangle: half of base (1) times height (0.75 ** 0.5).
_AREA = 0.5 * 0.75 ** 0.5
# Triangulation built from the three corners.
_triangle = tri.Triangulation(_corners[:, 0], _corners[:, 1])
# For each corner of the triangle, the pair of other corners
# (corner i is paired with corners i+1 and i+2, taken modulo 3).
_pairs = [_corners[[(i + 1) % 3, (i + 2) % 3]] for i in range(3)]
def tri_area(xy, pair):
    '''Returns the area of the triangle formed by `xy` and a pair of points.

    Arguments:
        `xy`: A length-2 sequence containing the x and y value.
        `pair`: A 2x2 array-like whose rows are the other two vertices.
    '''
    # Edge vectors running from `xy` to each vertex of the pair.
    (ux, uy), (vx, vy) = np.asarray(pair) - np.asarray(xy)
    # Half the magnitude of the 2-D cross product. It is written out
    # explicitly because calling np.cross on 2-element vectors is deprecated
    # as of NumPy 2.0.
    return 0.5 * abs(ux * vy - uy * vx)
def xy2bc(xy, tol=1.e-4):
    '''Maps a 2D Cartesian point to barycentric coordinates.

    Arguments:
        `xy`: A length-2 sequence containing the x and y value.
        `tol`: Margin kept away from 0 and 1; every coordinate is clipped
               to the range [tol, 1 - tol].
    '''
    # Each coordinate is the fraction of the full triangle's area taken up
    # by the sub-triangle that `xy` forms with one pair of corners.
    areas = [tri_area(xy, opposite) for opposite in _pairs]
    fractions = np.array(areas) / _AREA
    # NOTE: because of the clipping, the result need not sum exactly to 1.
    return np.clip(fractions, tol, 1.0 - tol)
class Dirichlet(object):
    '''Dirichlet distribution with a fixed parameter vector `alpha`.'''

    def __init__(self, alpha):
        '''Creates Dirichlet distribution with parameter `alpha`.

        Arguments:
            `alpha`: A sequence of numbers, one per component.
        '''
        import math

        # Stored as a float copy so integer input behaves like float input.
        self._alpha = np.array(alpha, dtype=float)
        # Normalizing constant gamma(sum(alpha)) / prod(gamma(alpha_i)).
        # It is evaluated in log space because math.gamma raises
        # OverflowError once its argument exceeds roughly 171, which a sum of
        # moderately large alphas reaches easily.
        log_coef = math.lgamma(self._alpha.sum()) - \
            sum(math.lgamma(a) for a in self._alpha)
        self._coef = math.exp(log_coef)

    def pdf(self, x):
        '''Returns pdf value for `x`.

        Arguments:
            `x`: A sequence iterated in lockstep with `alpha`; entry i is
                 raised to the power alpha[i] - 1 and the results are
                 multiplied together.
        '''
        powers = [xx ** (aa - 1) for (xx, aa) in zip(x, self._alpha)]
        return self._coef * np.multiply.reduce(powers)

    def sample(self, N):
        '''Generates a random sample of size `N`.'''
        return np.random.dirichlet(self._alpha, N)
def draw_pdf_contours(dist, border=False, nlevels=200, subdiv=8, **kwargs):
    '''Draws pdf contours over an equilateral triangle (2-simplex).

    Arguments:
        `dist`: A distribution instance with a `pdf` method.
        `border` (bool): If True, the simplex border is drawn.
        `nlevels` (int): Number of contours to draw.
        `subdiv` (int): Number of recursive mesh subdivisions to create.
        kwargs: Keyword args passed on to `plt.tricontourf`. The colormap
                is 'jet' unless a `cmap` keyword is supplied.
    '''
    refiner = tri.UniformTriRefiner(_triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    # Evaluate the pdf at every mesh vertex, in barycentric coordinates.
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    # Apply the default colormap only when the caller did not choose one;
    # passing cmap both explicitly and through kwargs would raise TypeError.
    kwargs.setdefault('cmap', 'jet')
    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)

    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    if border is True:
        plt.triplot(_triangle, linewidth=1)
def plot_points(X, barycentric=True, border=True, **kwargs):
    '''Plots a set of points in the simplex.

    Arguments:
        `X` (ndarray): An Nx2 array (if in Cartesian coords) or Nx3 array
                       (if in barycentric coords) of points to plot, one
                       point per row.
        `barycentric` (bool): Indicates if `X` is in barycentric coords.
        `border` (bool): If True, the simplex border is drawn.
        kwargs: Keyword args passed on to `plt.plot`. The marker size is 1
                unless `ms` or `markersize` is supplied.
    '''
    # Accept any array-like (e.g. a list of rows), not only ndarrays.
    X = np.asarray(X)
    if barycentric is True:
        # Each barycentric row is a weighted combination of the corners.
        X = X.dot(_corners)

    # Apply the default marker size only when the caller did not set one;
    # supplying it twice (under either name) would raise an error.
    if 'ms' not in kwargs and 'markersize' not in kwargs:
        kwargs['ms'] = 1
    plt.plot(X[:, 0], X[:, 1], 'k.', **kwargs)

    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    if border is True:
        plt.triplot(_triangle, linewidth=1)
if __name__ == '__main__':
    # Demo: for each parameter vector, draw the pdf contours in the top row
    # and a scatter of random samples directly beneath it.
    fig = plt.figure(figsize=(8, 6))
    alphas = [
        [0.999] * 3,
        [5] * 3,
        [2, 5, 15],
    ]
    ncols = len(alphas)
    for col, alpha in enumerate(alphas, start=1):
        dist = Dirichlet(alpha)

        plt.subplot(2, ncols, col)
        draw_pdf_contours(dist)
        plt.title(r'$\alpha$ = (%.3f, %.3f, %.3f)' % tuple(alpha),
                  fontdict={'fontsize': 8})

        plt.subplot(2, ncols, col + ncols)
        plot_points(dist.sample(5000))

    plt.savefig('dirichlet_plots.png')
    print('Wrote plots to "dirichlet_plots.png".')
This is a late reply because - for some reason - I never received notification of your comment. Did you edit the code before running? I ask because the line referenced in your error (line 60) doesn't match what is in my code. Your line is
refiner = tri.trirefine.UniformTriRefiner(_triangle)
which has an extra trirefine
submodule referenced. The line in my code is just
refiner = tri.UniformTriRefiner(_triangle)
Hi,
For those using Python 3, you should add:
from functools import reduce
Thank you for creating the script and helping me build more intuition for the Dirichlet Distribution :-)
It came to my attention that the function xy2bc was incorrect, which resulted in varying inaccuracy over the simplex. While it didn't appear to make a difference for the tolerance used, I've updated this gist with a corrected implementation that uses fractional triangle areas to compute the barycentric coordinates. I also made some minor edits to account for python and matplotlib API changes since the original post.
Really useful - thanks!
Your code is very nice, showing how to implement the Dirichlet distribution straight from the formula. I've also experimented with calling the scipy.stats.dirichlet library instead. It worked well, but we needed to change the tolerance of the xy2bc function from 1e-4 to 1e-9. Otherwise the following assertion in the library's _multivariate.py won't let the code run:
if (np.abs(np.sum(x, 0) - 1.0) > 10e-10).any():
raise ValueError("The input vector 'x' must lie within the normal "
"simplex. but np.sum(x, 0) = %s." % np.sum(x, 0))
Very strange. I have matplotlib 1.3.1 and I get this error: