Skip to content

Instantly share code, notes, and snippets.

@tatsy
Created January 4, 2022 04:43
Show Gist options
  • Save tatsy/6d19e4e37e12dd75ff9505c7bb684f28 to your computer and use it in GitHub Desktop.
Save tatsy/6d19e4e37e12dd75ff9505c7bb684f28 to your computer and use it in GitHub Desktop.
Slice sampling
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Number of post-burn-in states to draw from the slice sampler\n",
"n_mutate = 10_000"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def inv_dist_fun(x, r=0.3, mu1=5.0, mu2=-5.0):\n",
"    \"\"\"Evaluate the target density: a two-component 1-D Gaussian mixture.\n",
"\n",
"    Despite the name, this returns the density itself (not an inverse):\n",
"    r * g(x - mu1) + (1 - r) * g(x - mu2) with g(t) = exp(-t * t) / sqrt(pi),\n",
"    i.e. each component has variance 1/2 and integrates to 1.\n",
"    Operates elementwise, so x may be a float or a NumPy array.\n",
"\n",
"    Args:\n",
"        x: evaluation point(s).\n",
"        r: weight of the component centred at mu1 (default 0.3).\n",
"        mu1: mean of the first component (default 5.0).\n",
"        mu2: mean of the second component (default -5.0).\n",
"    \"\"\"\n",
"    x1 = x - mu1\n",
"    x2 = x - mu2\n",
"    e1 = np.exp(-x1 * x1) / np.sqrt(np.pi)\n",
"    e2 = np.exp(-x2 * x2) / np.sqrt(np.pi)\n",
"    return r * e1 + (1.0 - r) * e2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class SliceSampling(object):\n",
"    \"\"\"One-dimensional slice sampler with stepping-out and shrinkage (Neal, 2003).\n",
"\n",
"    Despite its name, ``inv_dist_fun`` is treated as the target density f\n",
"    (it only needs to be proportional to it): every update draws a height\n",
"    u ~ U(0, f(x)), brackets the slice {x' : f(x') > u} and then samples\n",
"    uniformly from the bracket, shrinking it on each rejection.\n",
"\n",
"    Args:\n",
"        x: initial state; f(x) must be > 0 or sample() will not terminate.\n",
"        inv_dist_fun: callable mapping a float to the (unnormalized) density.\n",
"        sigma: stored on the instance but not used by the sampler.\n",
"        seed: seed for the sampler's private ``np.random.RandomState``.\n",
"        w: stepping-out width (default 10.0).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, x, inv_dist_fun, sigma, seed=0, w=10.0):\n",
"        self.rng = np.random.RandomState(seed)\n",
"        self.x = x\n",
"        self.inv_dist_fun = inv_dist_fun\n",
"        self.sigma = sigma  # not used by sample(); kept for API compatibility\n",
"        self.w = w\n",
"\n",
"    def sample(self):\n",
"        \"\"\"Perform one slice-sampling update and return the new state.\"\"\"\n",
"        f = self.inv_dist_fun\n",
"        x0 = self.x\n",
"        # Height of the horizontal slice, uniform under the density at x0.\n",
"        u = f(x0) * self.rng.random()\n",
"\n",
"        # Stepping out: place the initial window of width w at a uniformly\n",
"        # random offset around x0, then extend each side by w until it leaves\n",
"        # the slice. Anchoring both ends at x0 instead would make the bracket\n",
"        # depend on x0's exact position and break detailed balance.\n",
"        x_min = x0 - self.w * self.rng.random()\n",
"        x_max = x_min + self.w\n",
"        while f(x_min) > u:\n",
"            x_min -= self.w\n",
"        while f(x_max) > u:\n",
"            x_max += self.w\n",
"\n",
"        # Shrinkage: propose uniformly inside the bracket; on rejection move\n",
"        # the endpoint on the proposal's side of x0 in to the proposal, so the\n",
"        # bracket always keeps x0 (which lies in the slice since u < f(x0)).\n",
"        while True:\n",
"            x1 = self.rng.uniform(x_min, x_max)\n",
"            if f(x1) > u:\n",
"                self.x = x1\n",
"                return self.x\n",
"            if x1 < x0:\n",
"                x_min = x1\n",
"            else:\n",
"                x_max = x1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Chain starts at x = 0; sigma is stored by the sampler but never used\n",
"sampler = SliceSampling(x=0.0, inv_dist_fun=inv_dist_fun, sigma=1.0)\n",
"# Burn-in length: 1% of the sample count, but never fewer than 100 updates\n",
"n_burnin = max(100, n_mutate // 100)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Burn-in: advance the chain and discard these states\n",
"for _ in range(n_burnin):\n",
"    sampler.sample()\n",
"\n",
"# Sampling: record n_mutate consecutive states of the chain\n",
"samples = [sampler.sample() for _ in range(n_mutate)]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"bins = 50\n",
"x_min, x_max = -8.0, 8.0\n",
"\n",
"# Histogram of the chain (normalized to a density over [x_min, x_max])\n",
"# overlaid with the exact target density.\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"ax.hist(samples, bins=bins, range=(x_min, x_max), density=True,\n",
"        color='tab:blue', label=\"samples\")\n",
"xs = np.linspace(x_min, x_max, 1000)\n",
"ys = inv_dist_fun(xs)\n",
"ax.plot(xs, ys, color='tab:red', linestyle='--', label=\"target\")\n",
"ax.set_ylim(0.0, 0.5)\n",
"ax.legend(loc=\"upper right\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "02d3c230780613a338b307efe1c8520401457e802ba0c889aabda605c4f627d4"
},
"kernelspec": {
"display_name": "Python 3.8.11 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment