@iamirmasoud
Created July 31, 2022 20:26
Importance Sampling in Python
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Importance Sampling\n",
"---\n",
"Using sampling to approximate the expectation of a function under a distribution\n",
"\n",
"$$E[f(x)] = \\int f(x)p(x) dx \\approx \\frac{1}{n}\\sum_{i} f(x_i)$$\n",
"where $x_i \\sim p(x)$\n",
"\n",
"$$E[f(x)] = \\int f(x)p(x) dx = \\int f(x)\\frac{p(x)}{q(x)}q(x) dx \\approx \\frac{1}{n} \\sum_{i} f(x_i)\\frac{p(x_i)}{q(x_i)}$$\n",
"\n",
"where $x_i \\sim q(x)$\n",
"\n",
"Idea of importance sampling: draw samples from a proposal distribution $q(x)$ and re-weight each sample by the importance weight $p(x_i)/q(x_i)$, so that the estimate still targets the expectation under $p(x)$.\n",
"\n",
"$$Var(X) = E[X^2] - E[X]^2$$\n",
"\n",
"**Reference**\n",
"\n",
"- [1](https://www.youtube.com/watch?v=3Mw6ivkDVZc)\n",
"- [2](https://astrostatistics.psu.edu/su14/lectures/cisewski_is.pdf)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.stats as stats\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [],
"source": [
"def f_x(x):\n",
"    # logistic sigmoid: the function whose expectation is estimated\n",
"    return 1/(1 + np.exp(-x))\n",
"\n",
"def distribution(mu=0, sigma=1):\n",
"    # return a frozen normal distribution; call .pdf(x) for the density at x\n",
"    return stats.norm(mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x115b111d0>"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[6, 4])\n",
"x = np.linspace(0, 4, 50)  # x ranges from 0 to 4\n",
"y = f_x(x)  # f_x works on arrays, so no loop is needed\n",
"\n",
"plt.plot(x, y, label=\"$f(x)$\")\n",
"\n",
"plt.xlabel(\"x\", size=18)\n",
"plt.ylabel(\"y\", size=18)\n",
"plt.legend(prop={\"size\": 14})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sampling"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [],
"source": [
"# settings: p(x) is the target distribution, q(x) is the proposal\n",
"n = 1000\n",
"\n",
"mu_target = 3.5\n",
"sigma_target = 1\n",
"mu_appro = 3\n",
"sigma_appro = 1\n",
"\n",
"p_x = distribution(mu_target, sigma_target)\n",
"q_x = distribution(mu_appro, sigma_appro)"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x117979f98>"
]
},
"execution_count": 170,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[10, 4])\n",
"\n",
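"# sns.distplot is deprecated in recent seaborn releases; histplot with a KDE overlay is a close replacement\n",
"sns.histplot(np.random.normal(mu_target, sigma_target, 3000), kde=True, stat=\"density\", label=\"distribution $p(x)$\")\n",
"sns.histplot(np.random.normal(mu_appro, sigma_appro, 3000), kde=True, stat=\"density\", label=\"distribution $q(x)$\")\n",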
"\n",
"plt.title(\"Distributions\", size=16)\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"simulate value 0.9542816022260111\n"
]
}
],
"source": [
"# plain Monte Carlo estimate of E[f(x)], sampling directly from p(x)\n",
"x_samples = np.random.normal(mu_target, sigma_target, size=n)\n",
"print(\"simulate value\", np.mean(f_x(x_samples)))"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average 0.9495227171370471 variance 0.3043862985463373\n"
]
}
],
"source": [
"# importance-sampling estimate of E_p[f(x)], sampling from the proposal q(x)\n",
"\n",
"x_samples = np.random.normal(mu_appro, sigma_appro, size=n)\n",
"# weight each f(x_i) by the importance weight p(x_i) / q(x_i)\n",
"value_list = f_x(x_samples) * (p_x.pdf(x_samples) / q_x.pdf(x_samples))\n",
"\n",
"print(\"average {} variance {}\".format(np.mean(value_list), np.var(value_list)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Different $q(x)$"
]
},
{
"cell_type": "code",
"execution_count": 179,
"metadata": {},
"outputs": [],
"source": [
"# settings: same target p(x), but a proposal q(x) centered much farther from it\n",
"n = 5000\n",
"\n",
"mu_target = 3.5\n",
"sigma_target = 1\n",
"mu_appro = 1\n",
"sigma_appro = 1\n",
"\n",
"p_x = distribution(mu_target, sigma_target)\n",
"q_x = distribution(mu_appro, sigma_appro)"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x117564ac8>"
]
},
"execution_count": 182,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[10, 4])\n",
"\n",
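"# sns.distplot is deprecated in recent seaborn releases; histplot with a KDE overlay is a close replacement\n",
"sns.histplot(np.random.normal(mu_target, sigma_target, 3000), kde=True, stat=\"density\", label=\"distribution $p(x)$\")\n",
"sns.histplot(np.random.normal(mu_appro, sigma_appro, 3000), kde=True, stat=\"density\", label=\"distribution $q(x)$\")\n",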
"\n",
"plt.title(\"Distributions\", size=16)\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average 0.9959984807502844 variance 83.36158359644132\n"
]
}
],
"source": [
"# importance-sampling estimate of E_p[f(x)] with the poorly matched proposal q(x)\n",
"\n",
"# q(x) rarely samples where p(x) has mass, so more samples are needed (n = 5000 here)\n",
"x_samples = np.random.normal(mu_appro, sigma_appro, size=n)\n",
"# weight each f(x_i) by the importance weight p(x_i) / q(x_i)\n",
"value_list = f_x(x_samples) * (p_x.pdf(x_samples) / q_x.pdf(x_samples))\n",
"\n",
"print(\"average {} variance {}\".format(np.mean(value_list), np.var(value_list)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
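
The two experiments in the notebook can be condensed into one small, self-contained sketch. It is not part of the original notebook: the helper name `importance_sampling_estimate`, the fixed seed, and the sample size of 5000 for both proposals are assumptions made here for illustration. The target and proposal parameters are the ones the notebook uses. The sketch draws from each proposal $q(x)$, weights $f(x_i)$ by $p(x_i)/q(x_i)$, and reports the mean and variance of the weighted values so the effect of a poorly matched proposal is visible side by side.

```python
import numpy as np
from scipy import stats


def f_x(x):
    """Logistic sigmoid, the function used in the notebook."""
    return 1.0 / (1.0 + np.exp(-x))


def importance_sampling_estimate(f, p, q, n, rng):
    """Estimate E_p[f(x)] from n draws of the proposal q.

    Hypothetical helper (not in the notebook). p and q are frozen
    scipy.stats distributions. Returns the mean and the variance of
    the weighted values f(x_i) * p(x_i) / q(x_i).
    """
    x = q.rvs(size=n, random_state=rng)      # x_i ~ q(x)
    values = f(x) * p.pdf(x) / q.pdf(x)      # f(x_i) * importance weight
    return values.mean(), values.var()


rng = np.random.default_rng(0)   # fixed seed: an arbitrary choice for repeatability
p = stats.norm(3.5, 1)           # target p(x), as in the notebook
q_close = stats.norm(3.0, 1)     # proposal from the first experiment
q_far = stats.norm(1.0, 1)       # proposal from the "Different q(x)" experiment

mean_close, var_close = importance_sampling_estimate(f_x, p, q_close, 5000, rng)
mean_far, var_far = importance_sampling_estimate(f_x, p, q_far, 5000, rng)

print("close proposal: average {:.4f} variance {:.4f}".format(mean_close, var_close))
print("far proposal:   average {:.4f} variance {:.4f}".format(mean_far, var_far))
```

Both runs target the same expectation, but the far proposal should show a much larger variance, which is why the notebook raises `n` for that case.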