@pierrelux
Created August 28, 2015 23:43
Inverse transform sampling
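```python
import numpy as np


class InverseTransform:
    def __init__(self, probs, rng=None):
        probs = np.asarray(probs, dtype=float)
        if not np.isclose(np.sum(probs), 1.):
            raise ValueError('probs must sum to 1')
        self.rng = np.random.RandomState() if rng is None else rng
        self.cdf = np.cumsum(probs)
        self.cdf[-1] = 1.  # guard against round-off so every draw in [0, 1) is covered

    def __call__(self, size=None):
        # Draw uniforms in [0, 1); each sample is the first index whose CDF exceeds its draw
        u = self.rng.rand(1 if size is None else size)
        samples = np.argmax(np.greater(self.cdf, u[:, np.newaxis]), axis=1)
        return samples[0] if size is None else samples
```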
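The class implements the usual inverse transform rule for a discrete distribution. A short derivation of why it samples correctly, writing $p_i$ for `probs[i]` and $F_i$ for `cdf[i]` (this notation is assumed here, not taken from the gist):

```latex
F_i = \sum_{j=0}^{i} p_j, \qquad F_{-1} = 0, \qquad U \sim \mathrm{Uniform}[0, 1)

X = \min\{\, i : F_i > U \,\}

\Pr[X = i] = \Pr[F_{i-1} \le U < F_i] = F_i - F_{i-1} = p_i
```

The `argmax` over the boolean array `cdf > u` returns the first `True` index, which is exactly the minimum in the second line.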
@timvieira

You should use `searchsorted` rather than `argmax` in `__call__`! It works because the CDF is sorted. That gives a runtime of O(size · log |probs|) rather than O(size · |probs|)!
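A sketch of the suggested change, assuming the same `probs`/`rng` interface as the class above and NumPy's legacy `RandomState`; the class name `SearchsortedInverseTransform` is hypothetical. Passing `side='right'` to `np.searchsorted` returns the first index `i` with `cdf[i] > u`, matching the strict comparison in the `argmax` version, so zero-probability entries are never returned.

```python
import numpy as np


class SearchsortedInverseTransform:
    """Hypothetical variant of InverseTransform using binary search over the CDF."""

    def __init__(self, probs, rng=None):
        probs = np.asarray(probs, dtype=float)
        # Binary search needs a nondecreasing CDF, i.e. nonnegative probabilities.
        if np.any(probs < 0) or not np.isclose(probs.sum(), 1.0):
            raise ValueError('probs must be nonnegative and sum to 1')
        self.rng = np.random.RandomState() if rng is None else rng
        self.cdf = np.cumsum(probs)
        self.cdf[-1] = 1.0  # guard against round-off in the last CDF entry

    def __call__(self, size=None):
        u = self.rng.rand(1 if size is None else size)
        # side='right' gives the first index i with cdf[i] > u, the same answer as
        # np.argmax(np.greater(cdf, u)), but in O(log |probs|) per draw.
        samples = np.searchsorted(self.cdf, u, side='right')
        return samples[0] if size is None else samples


if __name__ == '__main__':
    probs = [0.1, 0.0, 0.6, 0.3]
    sampler = SearchsortedInverseTransform(probs, rng=np.random.RandomState(0))
    draws = sampler(100000)
    # Empirical frequencies should be close to probs.
    print(np.bincount(draws, minlength=len(probs)) / len(draws))
```

The only change from the `argmax` version is the lookup: the broadcasted `size × |probs|` comparison is replaced by one binary search per draw, which also avoids allocating that intermediate boolean array.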
