Skip to content

Instantly share code, notes, and snippets.

@sharanry
Created June 6, 2018 01:15
Show Gist options
  • Save sharanry/6b6439d81b4b0e18a46b52ed542a8387 to your computer and use it in GitHub Desktop.
Save sharanry/6b6439d81b4b0e18a46b52ed542a8387 to your computer and use it in GitHub Desktop.
Using tensorflow_probability's HMC sampler with pymc4
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"import numpy as np\n",
"tfd = tf.contrib.distributions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class HMC(object):\n",
"    \"\"\"Per-variable Hamiltonian Monte Carlo sampler (TF1 graph mode, tfp.mcmc).\n",
"\n",
"    Every entry of ``model.named_vars`` gets its own state variable, its own\n",
"    adaptive step size and its own ``tfp.mcmc.HamiltonianMonteCarlo`` kernel\n",
"    whose target is that variable's ``distribution.log_prob``.\n",
"\n",
"    NOTE(review): each kernel only sees its own variable's distribution, not a\n",
"    joint log-density of the model, so this is presumably only valid when the\n",
"    model's random variables are independent (e.g. a single Normal) - confirm.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model, target_accept_rate=0.651, num_leapfrog_steps=3):\n",
"        \"\"\"Create one state variable and one step-size variable per model variable.\n",
"\n",
"        Args:\n",
"            model: pymc4 model; only ``model.named_vars`` is read here.\n",
"            target_accept_rate: acceptance probability that the warm-up\n",
"                step-size adaptation steers towards.\n",
"            num_leapfrog_steps: leapfrog steps per HMC proposal.\n",
"        \"\"\"\n",
"        self.target_accept_rate = target_accept_rate\n",
"        self.num_leapfrog_steps = num_leapfrog_steps\n",
"        self.dtype = np.float32\n",
"\n",
"        self.x = {}\n",
"        self.step_size = {}\n",
"        for name in model.named_vars:\n",
"            # tf.Variable uniquifies op names ('x', 'x_1', ...), so several model\n",
"            # variables (or re-running this cell) do not collide the way\n",
"            # tf.get_variable(name='x') does.\n",
"            self.x[name] = tf.Variable(self.dtype(1), name='x')\n",
"            self.step_size[name] = tf.Variable(self.dtype(1), name='step_size')\n",
"\n",
"    def sample(self, model, draws, tune):\n",
"        \"\"\"Run ``tune`` adapting warm-up steps, then collect ``draws`` samples.\n",
"\n",
"        Args:\n",
"            model: the pymc4 model passed to ``__init__``.\n",
"            draws: number of samples to collect after warm-up.\n",
"            tune: number of warm-up steps during which step sizes adapt.\n",
"\n",
"        Returns:\n",
"            List of ``draws`` dicts mapping each variable name to its sampled value.\n",
"        \"\"\"\n",
"        print(\"Initializing HMC Sampler...\")\n",
"        # One entry per model variable; the dicts live outside the loop so that\n",
"        # every variable (not just the last one) is warmed up and sampled.\n",
"        x_update = {}\n",
"        warmup = {}\n",
"        for name in model.named_vars:\n",
"            kernel = tfp.mcmc.HamiltonianMonteCarlo(\n",
"                target_log_prob_fn=model.named_vars[name].distribution.log_prob,\n",
"                step_size=self.step_size[name],\n",
"                num_leapfrog_steps=self.num_leapfrog_steps)\n",
"            # One HMC transition starting from the variable's current value.\n",
"            next_x, kernel_results = kernel.one_step(\n",
"                current_state=self.x[name],\n",
"                previous_kernel_results=kernel.bootstrap_results(self.x[name]))\n",
"            x_update[name] = self.x[name].assign(next_x)\n",
"\n",
"            # Grow the step size by 1% when the acceptance probability exceeds\n",
"            # the target, shrink it by 1% otherwise.\n",
"            accept_prob = tf.exp(tf.minimum(kernel_results.log_accept_ratio, 0.))\n",
"            step_size_update = self.step_size[name].assign_add(\n",
"                self.step_size[name] * tf.where(\n",
"                    accept_prob > self.target_accept_rate, 0.01, -0.01))\n",
"            # Step-size adaptation is only run during warm-up.\n",
"            warmup[name] = tf.group([x_update[name], step_size_update])\n",
"\n",
"        init = tf.global_variables_initializer()\n",
"        with tf.Session() as sess:\n",
"            sess.run(init)\n",
"            for _ in range(tune):\n",
"                sess.run(warmup)\n",
"            # Fetching the assign ops yields the freshly written values. Fetching\n",
"            # self.x alongside x_update had no ordering guarantee relative to\n",
"            # the assignment within the same run.\n",
"            return [sess.run(x_update) for _ in range(draws)]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pymc4 as pm\n",
"from tensorflow_probability import edward2 as ed\n",
"from tensorflow_probability import distributions as tfd\n",
"from pymc4 import RandomVariable as RV"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Toy model: a single standard-normal random variable named \"x\".\n",
"with pm.Model() as model:\n",
"    x = RV(\"x\", tfd.Normal(loc=0., scale=1.))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Sampler with the default target acceptance rate and leapfrog step count.\n",
"hmc = HMC(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<bound method Distribution.log_prob of <tf.distributions.Normal 'Normal' batch_shape=() event_shape=() dtype=float32>>\n"
]
}
],
"source": [
"# Show the log-density callable that each HMC kernel targets.\n",
"for var_name in model.named_vars:\n",
"    log_prob_fn = model.named_vars[var_name].distribution.log_prob\n",
"    print(log_prob_fn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 500 adapting warm-up steps, then 1000 recorded draws.\n",
"trace = hmc.sample(model, tune=500, draws=1000)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sharan/anaconda3/envs/pymc3/lib/python3.6/site-packages/matplotlib/axes/_axes.py:6462: UserWarning: The 'normed' kwarg is deprecated, and has been replaced by the 'density' kwarg.\n",
" warnings.warn(\"The 'normed' kwarg is deprecated, and has been \"\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"# Plot the distribution of the sampled values of \"x\".\n",
"x_draws = [draw['x'] for draw in trace]\n",
"sns.distplot(x_draws)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (pymc3)",
"language": "python",
"name": "pymc3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment