Skip to content

Instantly share code, notes, and snippets.

@timedcy
Created February 9, 2018 09:23
Show Gist options
  • Save timedcy/e32333a4fdfea424281141d8b1f238fc to your computer and use it in GitHub Desktop.
Save timedcy/e32333a4fdfea424281141d8b1f238fc to your computer and use it in GitHub Desktop.
"""
Created on 2017-10-25
@author: timedcy@gmail.com
"""
import numpy as np
import numpy.random as npr
class AliasSample(object):
    """Draw samples from a discrete distribution with the alias method.

    Construction builds two tables of length ``K`` (the number of outcomes);
    each call to :meth:`choose` then needs only two uniform random draws.

    Slots:
        K: number of outcomes, ``len(probs)``.
        q: float array; ``q[k]`` is the probability of keeping bucket ``k``'s
            own outcome once bucket ``k`` has been picked uniformly.
        J: int array; ``J[k]`` is the alias outcome returned when bucket
            ``k`` is picked but its own outcome is rejected.
    """

    __slots__ = ('K', 'q', 'J')

    def __init__(self, probs):
        """
        Compute utility lists for non-uniform sampling from discrete distributions.
        Refer to
        https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
        for details

        Args:
            probs: sequence of outcome probabilities. It is expected to sum
                to 1; the values are used as given and are not normalized.
        """
        K = len(probs)
        q = np.zeros(K)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the
        # dtype it used to alias.
        J = np.zeros(K, dtype=int)

        # Sort the data into the outcomes with probabilities that are larger
        # and smaller than 1/K.
        smaller, larger = [], []
        for k, prob in enumerate(probs):
            q[k] = K * prob
            if q[k] < 1.0:
                smaller.append(k)
            else:
                larger.append(k)

        # Loop through and create little binary mixtures that appropriately
        # allocate the larger outcomes over the overall uniform mixture.
        while smaller and larger:
            small, large = smaller.pop(), larger.pop()
            J[small] = large
            # ``large`` donates the mass needed to top bucket ``small`` up to 1.
            q[large] = q[large] - (1.0 - q[small])
            if q[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        # Whatever is left was never paired with an alias. For a normalized
        # input its threshold differs from 1 only by rounding error, so pin it
        # to 1; otherwise a value slightly below 1 would leak probability to
        # the default alias, outcome 0.
        for leftover in smaller + larger:
            q[leftover] = 1.0

        self.K = K
        self.q = q
        self.J = J

    def choose(self):
        """Return one sampled outcome index in ``range(K)`` as a Python int."""
        # Pick a bucket uniformly, then keep it or switch to its alias.
        k = int(np.floor(npr.rand() * self.K))
        return k if npr.rand() < self.q[k] else int(self.J[k])
if __name__ == '__main__':
    # Demo: draw N samples from a random K-outcome distribution and compare
    # the empirical frequencies with the true probabilities.
    K = 10
    N = 10000
    probs = npr.dirichlet(np.ones(K), 1).ravel()
    sampler = AliasSample(probs)

    X = np.zeros(N)
    for nn in range(N):
        X[nn] = sampler.choose()

    # bincount with minlength=K keeps one slot per outcome, so an outcome that
    # was never drawn shows up as 0 instead of shifting every later frequency
    # out of alignment with ``probs``.
    counts = np.bincount(X.astype(int).ravel(), minlength=K)
    s = counts.sum()
    cnt = [float(c) / s for c in counts]

    print('probs', probs[:10])
    print('cnt ', cnt[:10])
    print('cnt vs sampled:')
    for a, b in zip(probs[:10], cnt[:10]):
        print('{:.4f}\t{:.4f}'.format(a, b))
    print('X', X[:10])
    # Matching argsort orders mean the sampler ranks outcomes like ``probs``.
    print('probs', np.asarray(probs).argsort()[:10])
    print('cnt ', np.asarray(cnt).argsort()[:10])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment