Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active June 26, 2024 18:25
Show Gist options
  • Save vankesteren/1854cf92b6ec26a3bb9628048bb2b9b6 to your computer and use it in GitHub Desktop.
Save vankesteren/1854cf92b6ec26a3bb9628048bb2b9b6 to your computer and use it in GitHub Desktop.
Comparing lintsampler to basic uniform importance sampling
# Comparing lintsampler to basic uniform importance sampling
from scipy.stats import norm, uniform
import numpy as np
import matplotlib.pyplot as plt
from lintsampler import LintSampler
# Number of draws requested from every sampler in this script (one million).
NSAMPLES = 1_000_000
# GMM example
def gmm_pdf(x):
    """Evaluate the density of a three-component 1D Gaussian mixture.

    The mixture has means (-3.0, 0.5, 2.5), standard deviations
    (1.0, 0.25, 0.75) and weights (0.4, 0.25, 0.35); the weights sum to one,
    so the result is a normalised density.

    Parameters
    ----------
    x : scalar or array_like
        Point(s) at which to evaluate the density.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Mixture density, with the same shape as ``x``.
    """
    mu = np.array([-3.0, 0.5, 2.5])
    sig = np.array([1.0, 0.25, 0.75])
    w = np.array([0.4, 0.25, 0.35])
    # Append a trailing axis so every point broadcasts against all components
    # at once; the number of components then follows from the parameter
    # arrays instead of a hard-coded range(3).
    points = np.asarray(x, dtype=float)[..., None]
    return np.sum(w * norm.pdf(points, mu, sig), axis=-1)
# importance sampling (weight uniform proposals, then resample)
rng = np.random.default_rng(42)
# Uniform proposal on [-12, 12]: scipy's uniform takes (loc, scale).
propdist = uniform(-12, 24)
proposals = propdist.rvs(NSAMPLES, random_state=rng)
# Importance weights are target / proposal. The proposal density is constant,
# so it cancels in the normalisation below and the division is omitted.
weights = gmm_pdf(proposals)
# Resample with replacement in proportion to the weights. Use the seeded
# generator rather than the global np.random state so the run is reproducible.
importance_samples = rng.choice(proposals, NSAMPLES, p=weights / weights.sum())
# plot: histogram of the resampled draws against the true mixture density
bins = np.linspace(-12, 12, 200)
plt.hist(importance_samples, bins=bins, density=True, label="Samples", fc="goldenrod")
plt.plot(bins, gmm_pdf(bins), label="True PDF", c='teal')
# The labels above are only rendered once a legend is requested.
plt.legend()
plt.show()
# compare to lintsamples
# Fresh generator with the same seed as the importance-sampling run.
rng = np.random.default_rng(42)
# 33 evenly spaced grid points spanning [-12, 12].
fixedgrid = np.linspace(-12, 12, 33)
lintsampler_samples = LintSampler(
    fixedgrid,
    pdf=gmm_pdf,
    vectorizedpdf=True,
    seed=rng,
).sample(N=NSAMPLES)
# Histogram of the lintsampler draws against the true mixture density.
plt.hist(lintsampler_samples, bins=bins, density=True, label="Samples", fc="goldenrod")
plt.plot(bins, gmm_pdf(bins), label="True PDF", c='teal')
# The labels above are only rendered once a legend is requested.
plt.legend()
plt.show()
# compare log-likelihood
# Total log-density of each sample set under the true mixture. The values are
# printed so they are visible when the file is run as a script, not only when
# the lines are pasted into an interactive session.
print("lintsampler log-likelihood:", np.log(gmm_pdf(lintsampler_samples)).sum())
# Original author's note on the importance-sampling value: higher log-likelihood!
print("importance sampling log-likelihood:", np.log(gmm_pdf(importance_samples)).sum())
# Doughnut example
def circles_pdf(x):
    """Evaluate an unnormalised two-ring ("doughnut") density in 2D.

    Each ring is a circle of radius 1 (centred at (-2, -2) and (2, 2)) with a
    Gaussian profile of width 0.4 in the distance from the circle. The two
    ring profiles are summed without a normalising constant, so each one
    peaks at 1 on its circle.

    Parameters
    ----------
    x : array_like, shape (..., 2)
        Point(s) at which to evaluate the density. Any number of leading
        axes is accepted, including a single point of shape (2,).

    Returns
    -------
    numpy.ndarray or numpy.float64
        Density values with shape ``x.shape[:-1]``.
    """
    rings = (
        (np.array([-2.0, -2.0]), 1.0),
        (np.array([2.0, 2.0]), 1.0),
    )
    width = 0.4
    points = np.asarray(x, dtype=float)
    density = 0.0
    for centre, radius in rings:
        # The closest point on a circle lies along the ray from its centre,
        # so the distance to the circle is | ||x - c|| - r |. This also holds
        # at the centre itself (distance r), so no special case is needed.
        dist = np.abs(np.linalg.norm(points - centre, axis=-1) - radius)
        density = density + np.exp(-0.5 * dist**2 / width**2)
    return density
# importance sampling (weight uniform proposals, then resample)
rng = np.random.default_rng(42)
# Draw 2*NSAMPLES uniform values on [-4, 4] (scipy's uniform takes loc, scale)
# and pair them up into NSAMPLES two-dimensional proposal points.
proposals = uniform(-4, 8).rvs(2 * NSAMPLES, random_state=rng).reshape(NSAMPLES, 2)
# The proposal density is constant, so the target density alone serves as the
# (unnormalised) importance weight.
imp_weights = circles_pdf(proposals)
# Resample row indices with replacement in proportion to the weights. Use the
# seeded generator rather than the global np.random state so the run is
# reproducible.
idx = rng.choice(NSAMPLES, NSAMPLES, p=imp_weights / imp_weights.sum())
importance_samples = proposals[idx, :]
# visual: 2D histogram of the resampled points on a 128 x 128 grid over [-4, 4]^2
plt.hist2d(
    importance_samples[:, 0],
    importance_samples[:, 1],
    bins=128,
    range=[[-4, 4], [-4, 4]],
    cmap='inferno',
)
plt.show()
# compare to lintsamples
# Fresh generator with the same seed as the importance-sampling run.
rng = np.random.default_rng(42)
N_grid = 128
# N_grid + 1 evenly spaced edges on [-4, 4], used for both axes.
edges = np.linspace(-4, 4, N_grid + 1)
lintsampler_samples = LintSampler(
    (edges, edges),
    pdf=circles_pdf,
    seed=rng,
    vectorizedpdf=True,
).sample(N=NSAMPLES)
# visual: 2D histogram of the lintsampler points on a 128 x 128 grid over [-4, 4]^2
plt.hist2d(
    lintsampler_samples[:, 0],
    lintsampler_samples[:, 1],
    bins=128,
    range=[[-4, 4], [-4, 4]],
    cmap='inferno',
)
plt.show()
# compare log-likelihood
# Total log of the (unnormalised) ring density over each sample set. The
# values are printed so they are visible when the file is run as a script,
# not only when the lines are pasted into an interactive session.
print("lintsampler log-likelihood:", np.log(circles_pdf(lintsampler_samples)).sum())
# Original author's note on the importance-sampling value: higher log-likelihood!
print("importance sampling log-likelihood:", np.log(circles_pdf(importance_samples)).sum())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment