Skip to content

Instantly share code, notes, and snippets.

@johnmeade
Last active July 13, 2019 20:28
Show Gist options
  • Save johnmeade/1ce44243afd5ab4a595da7c4aa129e19 to your computer and use it in GitHub Desktop.
Save johnmeade/1ce44243afd5ab4a595da7c4aa129e19 to your computer and use it in GitHub Desktop.
Fast sampling from arbitrary probability densities in Python
'''
Tools for sampling from arbitrary probability densities.
Requirements:
pip install scipy numpy
(running this file as a script additionally needs matplotlib for the demo plots)
John Meade 2019
MIT license
'''
import numpy as np
import scipy.interpolate
from random import random
from multiprocessing import cpu_count, Pool
def pdf2cdf(pdf, xmin, xmax, resolution=100):
    '''
    Create an approximate CDF function via numerical integration of a PDF.

    The PDF is evaluated on a uniform grid of `resolution` points spanning
    [xmin, xmax] and integrated cumulatively with the trapezoidal rule, so the
    tabulated CDF is exactly 0 at `xmin` and, after normalization, exactly 1
    at `xmax`. Probability mass outside [xmin, xmax] is ignored; the result
    is renormalized over the interval.

    Args:
        pdf (callable): vectorized callable PDF function. It is called once
            with a 1-D float array of grid points and should return an array
            of the same length (integer-valued output is accepted).
        xmin (float): left boundary of the approximation domain
        xmax (float): right boundary of the approximation domain
        resolution (int): number of grid points; higher is more accurate.
            Cubic interpolation needs at least 4 points.

    Returns:
        cdf (callable): cubic `scipy.interpolate.interp1d` interpolant of the
            CDF over [xmin, xmax] (it raises ValueError outside that range).

    Raises:
        ValueError: if the PDF integrates to a non-positive or non-finite
            value over [xmin, xmax], since it cannot then be normalized.
    '''
    x = np.linspace(xmin, xmax, resolution)
    # Force float so integer-valued PDFs don't break the in-place division.
    y = np.asarray(pdf(x), dtype=float)
    # Cumulative trapezoidal integral; the leading 0 pins CDF(xmin) = 0.
    areas = 0.5 * (y[1:] + y[:-1]) * np.diff(x)
    cs = np.concatenate(([0.0], np.cumsum(areas)))
    total = cs[-1]
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            'pdf must integrate to a positive, finite value over '
            '[xmin, xmax]; got %r' % (total,))
    # normalize so the CDF reaches exactly 1 at xmax
    cs /= total
    cdf = scipy.interpolate.interp1d(x, cs, kind='cubic', assume_sorted=True)
    return cdf
def flatten(xs):
    'Concatenate an iterable of iterables into a single list (one level deep).'
    return [item for group in xs for item in group]
def chunks(tot, n):
    '''
    Divide the number `tot` into `n` near-equal integer parts.

    The parts are the gaps between consecutive rounded multiples of
    ``tot / n``; for an integer `tot` they add back up to `tot`.

    Args:
        tot: the number that will be split into chunks
        n: the number of chunks to split into

    Example:
        >>> chunks(53, n=8)
        [7, 6, 7, 6, 7, 7, 6, 7]
        >>> sum(chunks(53, n=8)) == 53
        True

    Returns:
        chunks (list): the chunks
    '''
    step = tot / n
    bounds = [int(round(k * step)) for k in range(n + 1)]
    return [upper - lower for lower, upper in zip(bounds[:-1], bounds[1:])]
def _sample(args):
'''
Worker function to perform lookup-sampling.
Args:
args (tuple): contains the lookup table, the CDF
range, and the number of samples to look up.
Returns:
samples (list): the samples
'''
lookup, rnge, n = args
samps = []
for _ in range(n):
x = random()
for i, r in enumerate(rnge):
if x <= r:
break
samps.append( lookup[ rnge[ i ] ] )
return samps
class DistSampler:
    '''
    Approximate distribution sampling via CDF inversion using lookups.

    The CDF is tabulated on a uniform grid of `resolution` points over
    [xmin, xmax]. Every sample is one of those grid points: the one at which
    the tabulated CDF first reaches a uniform random number. `resolution`
    therefore controls both speed and how finely spaced the samples are.

    The sampler owns a `multiprocessing.Pool`; call `close()` (or use the
    sampler in a `with` statement) to shut its worker processes down.

    Example:
        import numpy as np
        import scipy.stats
        import matplotlib.pyplot as plt
        kappa = 1.25
        vmcdf = lambda x: scipy.stats.vonmises.cdf(x, kappa=kappa)
        with DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi) as vm:
            plt.hist(vm.sample(k=1e4), bins=16)
        plt.show()
    '''

    def __init__(self, cdf, xmin, xmax, resolution=100, processes=None):
        '''
        Args:
            cdf (callable): vectorized callable CDF function
            xmin (float): left boundary of the approximation domain
            xmax (float): right boundary of the approximation domain
            resolution (int): accuracy of the approximation (this
                has a big impact on sampling speed)
            processes (int, optional): number of worker processes in the
                pool; defaults to ``2 * cpu_count()``.
        '''
        domain = np.linspace(xmin, xmax, resolution)
        self.rnge = cdf(domain)
        # Map each tabulated CDF value to the FIRST grid point attaining it.
        # Where the CDF is flat (zero-density stretches), several grid points
        # share one value; keeping the last of them (as a plain dict
        # comprehension would) pushes samples to the far end of the flat run.
        self.lookup = {}
        for r, x in zip(self.rnge, domain):
            self.lookup.setdefault(r, x)
        self.procs = 2 * cpu_count() if processes is None else int(processes)
        self.pool = Pool(self.procs)

    def sample(self, k=1):
        '''
        Draw samples from the distribution. Uses multiprocessing for speed.

        Args:
            k (int): Number of samples to draw (converted with `int`, so
                `1e4` is accepted). Use this instead of calling this method
                with `k=1` repeatedly!

        Returns:
            samples (list): the approximate samples.
        '''
        k = int(k)
        # Over-split into 2x as many chunks as workers to balance load.
        chnks = chunks(k, n=2*self.procs)
        args = [ (self.lookup, self.rnge, n) for n in chnks ]
        res = self.pool.map_async(_sample, args)
        samps = flatten(res.get())
        return samps

    def close(self):
        '''Stop accepting work and wait for the worker processes to exit.'''
        self.pool.close()
        self.pool.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions from the with-body
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    def _show_hist(sampler):
        # Draw 1e5 samples and display them as a 16-bin histogram.
        plt.hist(sampler.sample(k=1e5), bins=16)
        plt.show()

    # ---- Example 1: von Mises distribution ---------------------------------
    print('Von Mises Distribution')
    mu = 0
    kappa = 1.25

    def vmpdf(x):
        # von Mises density; reads the module-level `mu` and `kappa` above
        return np.exp(kappa * np.cos(x - mu)) / (2 * np.pi * np.i0(kappa))

    vmcdf = pdf2cdf(vmpdf, xmin=-np.pi, xmax=np.pi)
    vm = DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi)
    _show_hist(vm)
    # Timing reported by the original author:
    #   %timeit vm.sample(k=1e4)
    #   83.4 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    #   => about 8us per sample

    # ---- Example 2: normal distribution ------------------------------------
    print('Normal Distribution')
    mu = 0
    sigma = 1

    def norm_pdf(x):
        # Gaussian density; reads the module-level `mu` and `sigma` above
        return np.sqrt(2 * np.pi * sigma**2)**(-1) * np.exp(-(x - mu)**2 / (2 * sigma**2))

    norm_cdf = pdf2cdf(norm_pdf, xmin=-5, xmax=5)
    norm = DistSampler(norm_cdf, xmin=-5, xmax=5)
    _show_hist(norm)
    # Timing reported by the original author:
    #   %timeit norm.sample(k=1e4)
    #   80.3 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    #   => about 8us per sample
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment