Skip to content

Instantly share code, notes, and snippets.

@turnersr
Last active August 29, 2015 14:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save turnersr/c5f305316a54bc872b15 to your computer and use it in GitHub Desktop.
Save turnersr/c5f305316a54bc872b15 to your computer and use it in GitHub Desktop.
Sampling from a Dirichlet distribution by sampling from an n-sphere and then projecting onto an n-simplex.
import numpy as np
def projection_on_to_simplex(sphere_point):
    """Project a vector onto the probability simplex.

    Implements the sort-and-threshold procedure from the references below:
    a single shift ``theta`` is subtracted from every coordinate and negative
    results are clipped to zero, so the output is non-negative and sums to 1.

    References:
        - Same implementation as https://gist.github.com/daien/1272551
        - "Efficient Projections onto the l1-Ball for Learning in High
          Dimensions" -
          http://machinelearning.org/archive/icml2008/papers/361.pdf

    Args:
        sphere_point: Non-empty 1-D array-like of finite real numbers.

    Returns:
        A 1-D float ``numpy.ndarray`` with the same length as the input,
        whose entries are >= 0 and sum to 1.

    Raises:
        ValueError: If the input is empty, is not 1-D, or contains NaN or
            infinite values.
    """
    # Accept lists/tuples as well as arrays; the arithmetic below needs an
    # ndarray.
    point = np.asarray(sphere_point, dtype=float)
    if point.ndim != 1 or point.size == 0:
        raise ValueError("sphere_point must be a non-empty 1-D sequence")
    if not np.isfinite(point).all():
        # NaN/inf make every threshold comparison below False, which would
        # otherwise surface as an opaque IndexError.
        raise ValueError("sphere_point must contain only finite values")

    # Coordinates in descending order and their running totals.
    mu = np.sort(point)[::-1]
    mu_sums = np.cumsum(mu)

    # rho: zero-based index of the last sorted coordinate that passes the
    # threshold test. Index 0 always passes for finite input, so the
    # selection is never empty.
    ranks = np.arange(1, point.shape[0] + 1)
    rho = np.nonzero(mu * ranks > (mu_sums - 1))[0][-1]

    # Shift chosen so that the first rho + 1 sorted coordinates sum to 1
    # after subtracting it.
    theta = (mu_sums[rho] - 1) / (rho + 1.0)
    simplex_point = (point - theta).clip(min=0)

    # Internal sanity check on the result, not input validation.
    assert np.abs(np.sum(simplex_point) - 1) <= .0000001
    return simplex_point
def sample_dirichlet(dimension, n_samples):
    """Generate random points on the probability simplex.

    Each sample is built in three steps: draw ``dimension`` independent
    normal values, scale the vector to unit Euclidean length (a point on the
    unit sphere), then map it onto the simplex with
    ``projection_on_to_simplex``.

    NOTE(review): the function name says "Dirichlet", but nothing in this
    code establishes which distribution the projected points follow --
    confirm before relying on it statistically.

    References:
        - A note on a method for generating points uniformly on
          n-dimensional spheres - http://dl.acm.org/citation.cfm?id=377946

    Args:
        dimension: Number of coordinates per sample; must be at least 1.
        n_samples: Number of samples to draw. Zero or a negative value
            yields an empty list.

    Returns:
        A list of ``n_samples`` 1-D ``numpy.ndarray`` objects, each of
        length ``dimension``.

    Raises:
        ValueError: If ``dimension`` is smaller than 1.
    """
    if dimension < 1:
        raise ValueError("dimension must be at least 1")

    # Parameters of the normal draw; constant across samples, so set once.
    # The scale is irrelevant after normalisation to unit length.
    mu, sigma = 0, 0.1

    simplex_points = []
    for _ in range(n_samples):
        s = np.random.normal(mu, sigma, dimension)
        # Normalise to unit Euclidean norm -> point on the unit sphere.
        sphere_point = s / np.sqrt(np.sum(s ** 2))
        simplex_points.append(projection_on_to_simplex(sphere_point))
    return simplex_points
# Demo: draw 50 points in 3 dimensions and report whether every one of them
# sums to 1. The comparison is applied to each sum (with a floating-point
# tolerance) rather than to the result of np.all, and the sums are
# materialised in a list so the check also works where map() is lazy.
simplex_points = sample_dirichlet(3, 50)
print(np.allclose([np.sum(point) for point in simplex_points], 1.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment