Last active
August 29, 2015 14:00
-
-
Save turnersr/c5f305316a54bc872b15 to your computer and use it in GitHub Desktop.
Sampling from a Dirichlet distribution by sampling from an n-sphere and then projecting onto an n-simplex.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
def projection_on_to_simplex(sphere_point):
    """Project a vector onto the probability simplex.

    Returns the vector ``w`` with ``w >= 0`` and ``sum(w) == 1`` obtained by
    shifting the input by a single threshold ``theta`` and clipping negative
    entries to zero (the sort-based projection algorithm).

    References:
        Same implementation as the one used at
        https://gist.github.com/daien/1272551
        "Efficient Projections onto the l1-Ball for Learning in High
        Dimensions" -
        http://machinelearning.org/archive/icml2008/papers/361.pdf

    Args:
        sphere_point: non-empty 1-D array-like of finite numbers.

    Returns:
        A 1-D float ``numpy.ndarray`` of the same length whose entries are
        non-negative and sum to 1 (within 1e-7).

    Raises:
        ValueError: if the input is empty, not 1-D, or contains NaN/inf.
    """
    point = np.asarray(sphere_point, dtype=float)
    if point.ndim != 1 or point.size == 0:
        raise ValueError("sphere_point must be a non-empty 1-D array")
    if not np.all(np.isfinite(point)):
        raise ValueError("sphere_point must contain only finite values")

    # Sort in decreasing order and accumulate the running sums.
    mu = np.sort(point)[::-1]
    mu_sums = np.cumsum(mu)
    ranks = np.arange(1, point.shape[0] + 1)

    # rho is the last (0-based) index whose sorted entry stays positive after
    # the shift; the first entry always qualifies, so the index list is
    # never empty for finite input.
    rho = np.nonzero(mu * ranks > (mu_sums - 1))[0][-1]
    theta = (mu_sums[rho] - 1) / (rho + 1.0)

    simplex_point = (point - theta).clip(min=0)
    # Internal sanity check on the result, not input validation.
    assert np.abs(np.sum(simplex_point) - 1) <= .0000001
    return simplex_point
def sample_dirichlet(dimension, n_samples):
    """Draw random points on the probability simplex.

    Each sample is produced by drawing a Gaussian vector, normalising it to
    unit Euclidean length (a point on the unit sphere), and projecting that
    point onto the simplex with ``projection_on_to_simplex``.

    NOTE(review): the function name says "Dirichlet", but the code itself
    does not establish which distribution the projected points follow -
    confirm before relying on it for a specific Dirichlet parameterisation.

    References:
        A note on a method for generating points uniformly on n-dimensional
        spheres - http://dl.acm.org/citation.cfm?id=377946

    Args:
        dimension: number of coordinates per sample; must be >= 1.
        n_samples: number of samples to draw.

    Returns:
        A list of ``n_samples`` 1-D arrays of length ``dimension``, each
        non-negative and summing to 1.

    Raises:
        ValueError: if ``dimension`` is smaller than 1.
    """
    if dimension < 1:
        raise ValueError("dimension must be at least 1")

    # Gaussian parameters are loop-invariant; sigma only rescales the draw
    # and is cancelled by the normalisation below.
    mu, sigma = 0, 0.1

    simplex_points = []
    for _ in range(n_samples):
        s = np.random.normal(mu, sigma, dimension)
        norm = np.sqrt(np.sum(s ** 2))
        # An all-zero draw cannot be normalised; redraw in that
        # (practically impossible) case instead of dividing by zero.
        while norm == 0:
            s = np.random.normal(mu, sigma, dimension)
            norm = np.sqrt(np.sum(s ** 2))
        sphere_point = s / norm
        simplex_points.append(projection_on_to_simplex(sphere_point))
    return simplex_points
# Demo: draw 50 points in 3 dimensions and confirm each one lies on the simplex.
simplex_points = sample_dirichlet(3, 50)
# Compare every per-sample sum against 1 with a floating-point tolerance;
# the previous check only tested truthiness of the sums and compared the
# aggregated boolean to 1.
print(np.allclose(np.sum(simplex_points, axis=1), 1.0))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment