This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import tensorflow_probability as tfp\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"tfd = tfp.distributions\n", | |
"\n", | |
"target_log_prob_fn = tfd.Normal(loc=0., scale=0.1).log_prob\n", | |
"num_burnin_steps = 500\n", | |
"num_results = 500\n", | |
"num_chains = 256" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING: Logging before flag parsing goes to stderr.\n", | |
"W0728 17:09:56.062024 140082415449856 deprecation.py:323] From /home/colin/miniconda3/envs/scratch3.6/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py:2451: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n" | |
] | |
} | |
], | |
"source": [ | |
"# One initial step size per chain (shape [num_chains]); the adaptation kernel below tunes them during burn-in.\n", | |
"init_step_size = tf.fill([num_chains], 0.25)\n", | |
"\n", | |
"kernel = tfp.mcmc.HamiltonianMonteCarlo(\n", | |
"    target_log_prob_fn=target_log_prob_fn,\n", | |
"    num_leapfrog_steps=8,\n", | |
"    step_size=init_step_size)\n", | |
"# Adapt step sizes only during burn-in, so the retained samples come from a fixed kernel.\n", | |
"kernel = tfp.mcmc.SimpleStepSizeAdaptation(\n", | |
"    inner_kernel=kernel, num_adaptation_steps=num_burnin_steps, target_accept_prob=0.6)\n", | |
"\n", | |
"# Trace, for every retained step, the adapted step size and the inner HMC log acceptance ratio.\n", | |
"samples, [step_size, log_accept_ratio] = tfp.mcmc.sample_chain(\n", | |
"    num_results=num_results,\n", | |
"    num_burnin_steps=num_burnin_steps,\n", | |
"    current_state=tf.zeros(num_chains),\n", | |
"    kernel=kernel,\n", | |
"    trace_fn=lambda _, pkr: [pkr.new_step_size,\n", | |
"                             pkr.inner_results.log_accept_ratio])\n", | |
"\n", | |
"# sample_chain traces only the num_results post-burn-in steps, so the trace already excludes\n", | |
"# burn-in. Slicing it again by num_burnin_steps would discard retained steps (here: all of them,\n", | |
"# giving a NaN mean). Averaging min(1, exp(log ratio)) over the results axis gives a per-chain\n", | |
"# acceptance rate.\n", | |
"p_accept = tf.reduce_mean(tf.exp(tf.minimum(log_accept_ratio, 0.)), axis=0)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAATBklEQVR4nO3df4zc913n8eeLXAkRJSIhm+Da7jkqPkQS0fSyGJ9yEoECcVN0SaXryREk/iMnQ0jvWqnSNQEJik7W5SQoXE4kJ5dWcUQhWNdysSA5LpgihEiTrkuuqWNCrMbXLLZi8+Ou5p+A3Td/zNfHeDPend2dndmdz/MhjXbmPZ/vzGe+mn3NZz/fz3w3VYUkqQ3fNOkOSJLGx9CXpIYY+pLUEENfkhpi6EtSQwx9SWrIkqGf5FuSvJDkfyc5muQXuvrVSZ5N8mr386q+bR5KcjzJK0lu76vfkuSl7r5HkmRtXpYkaZBhRvpvAj9UVe8GbgZ2JdkJPAgcrqrtwOHuNkluAHYDNwK7gEeTXNY91mPAXmB7d9k1wtciSVrCP1mqQfW+vfW33c23dZcC7gRu6+oHgD8EPtbVn6yqN4HXkhwHdiQ5AVxZVc8BJHkCuAt4ZrHnv+aaa2rbtm3LeU2S1LwjR478ZVXNLKwvGfoA3Uj9CPBdwK9W1fNJrquqUwBVdSrJtV3zzcAX+jaf72p/311fWB/0fHvp/UXAO9/5Tubm5obppiSpk+T/DKoPdSC3qs5X1c3AFnqj9psWe65BD7FIfdDz7a+q2aqanZl5yweVJGmFlrV6p6r+L71pnF3AG0k2AXQ/T3fN5oGtfZttAU529S0D6pKkMRlm9c5Mkm/vrl8B/DDwZ8AhYE/XbA/wVHf9ELA7yeVJrqd3wPaFbirobJKd3aqde/u2kSSNwTBz+puAA928/jcBB6vqd5I8BxxMch/wNeCDAFV1NMlB4GXgHPBAVZ3vHut+4HHgCnoHcBc9iCtJGq2s91Mrz87OlgdyJWl5khypqtmFdb+RK0kNMfQlqSGGviQ1xNCXpIYY+hKw7cHfZduDvzvpbkhrztCXpIYY+pLUEENfkhpi6EtSQwx9aREe4NW0MfQlqSGGvqaeo3XpHxn6ktQQQ1/q418FmnaGviQ1xNCXpIYM85+zpOY4xaNp5Uhfkhpi6EtD8iCvpoGhr+YY3mqZc/pqhkEvOdKX/DBQUwx9SWqIoS+tkMcGtBEZ+pLUEA/kqlnLGaX3t3V0r43Mkb4kNWTJ0E+yNcnnkxxLcjTJh7v6x5P8RZIXu8sdfds8lOR4kleS3N5XvyXJS919jyTJ2rwstcp5dmlxw0zvnAM+WlVfSvJtwJEkz3b3/XJV/WJ/4yQ3ALuBG4F3AL+f5J9V1XngMWAv8AXgaWAX8MxoXor0jy4E/4mH3z/hnkjry5Ij/ao6VVVf6q6fBY4BmxfZ5E7gyap6s6peA44DO5JsAq6squeqqoAngLtW/QqkRTjqly62rDn9JNuA9wDPd6UPJflykk8nuaqrbQZe79tsvqtt7q4vrA96nr1J5pLMnTlzZjldlCbODxqtZ0OHfpK3A58FPlJVX6c3VfMu4GbgFPBLF5oO2LwWqb+1WLW/qmaranZmZmbYLkoT4XEEbSRDhX6St9EL/M9U1ecAquqNqjpfVd8APgns6JrPA1v7Nt8CnOzqWwbUJUljMszqnQCfAo5V1Sf66pv6mn0A+Ep3/RCwO8nlSa4HtgMvVNUp4GySnd1j3gs8NaLXIUkawjCrd24F7gFeSvJiV/sZ4O4kN9ObojkB/CRAVR1NchB4md7Knwe6lTsA9wOPA1fQW7Xjyh1JGqMlQ7+q/pjB8/FPL7LNPmDfgPoccNNyOihJGh2/kStJDTH0Jakhhr4kNcTQl6SGeGplTQW/HCUNx5G+NiS/BSutjCN9aUT8ENJG4Ehfkhpi6EtSQwx9SWqIoS9JDfFArj
Y0D55Ky+NIX5IaYuhLa2Dh9wj8XoHWC0Nfkhpi6EtSQwx9bRhOkUirZ+hLY+QHlybN0Jekhhj6ktQQQ1+SGmLoS1JDDH1Jaojn3pHWkCt1tN440pekhhj60gT4F4AmxdDXhmNgSiu3ZOgn2Zrk80mOJTma5MNd/eokzyZ5tft5Vd82DyU5nuSVJLf31W9J8lJ33yNJsjYvS5I0yDAj/XPAR6vqe4CdwANJbgAeBA5X1XbgcHeb7r7dwI3ALuDRJJd1j/UYsBfY3l12jfC1SJKWsOTqnao6BZzqrp9NcgzYDNwJ3NY1OwD8IfCxrv5kVb0JvJbkOLAjyQngyqp6DiDJE8BdwDMjfD2aQk7nSKOzrDn9JNuA9wDPA9d1HwgXPhiu7ZptBl7v22y+q23uri+sD3qevUnmksydOXNmOV2UJC1i6NBP8nbgs8BHqurrizUdUKtF6m8tVu2vqtmqmp2ZmRm2i5KkJQwV+kneRi/wP1NVn+vKbyTZ1N2/CTjd1eeBrX2bbwFOdvUtA+pSkzzNsiZhmNU7AT4FHKuqT/TddQjY013fAzzVV9+d5PIk19M7YPtCNwV0NsnO7jHv7dtGkjQGw5yG4VbgHuClJC92tZ8BHgYOJrkP+BrwQYCqOprkIPAyvZU/D1TV+W67+4HHgSvoHcD1IK4kjdEwq3f+mMHz8QDvvcQ2+4B9A+pzwE3L6aAkaXT8Rq4kNcSzbGrd8iCnNHqO9KUJcxWPxsnQl6SGGPqS1BBDX5IaYuhrXXFuW1pbhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiOfe0brjsk1p7TjSl6SGGPqS1BBDX5IaYuhLUkMMfWmd8fz6WkuGviQ1xCWbWhcc2UrjYehL64QffBoHp3ckqSGGviQ1xNCX1imne7QWDH1Jaoihr4lwLbo0GUuGfpJPJzmd5Ct9tY8n+YskL3aXO/rueyjJ8SSvJLm9r35Lkpe6+x5JktG/HEnSYoYZ6T8O7BpQ/+Wqurm7PA2Q5AZgN3Bjt82jSS7r2j8G7AW2d5dBjylJWkNLrtOvqj9Ksm3Ix7sTeLKq3gReS3Ic2JHkBHBlVT0HkOQJ4C7gmZV0WtPDKR5pvFYzp/+hJF/upn+u6mqbgdf72sx3tc3d9YV1SdIYrTT0HwPeBdwMnAJ+qasPmqevReoDJdmbZC7J3JkzZ1bYRWnj84C3Rm1FoV9Vb1TV+ar6BvBJYEd31zywta/pFuBkV98yoH6px99fVbNVNTszM7OSLkpTyQ8BrdaKzr2TZFNVnepufgC4sLLnEPAbST4BvIPeAdsXqup8krNJdgLPA/cC/3V1XZfaYdBrVJYM/SS/CdwGXJNkHvh54LYkN9ObojkB/CRAVR1NchB4GTgHPFBV57uHup/eSqAr6B3A9SCuJI3ZMKt37h5Q/tQi7fcB+wbU54CbltU7SdJI+Y1cSWqIoS9JDfGfqGisPCApTZYjfUlqiKEvSQ0x9DU2Tu1Ik2foS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+tIG5Hn1tVKGviQ1xNCXpIYY+pLUEENf2sCc29dyGfqS1BBDX5Ia4j9R0Zpx2kFafxzpS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpoArpTQsQ1+SGuI6fY2co05p/VpypJ/k00lOJ/lKX+3qJM8mebX7eVXffQ8lOZ7klSS399VvSfJSd98jSTL6lyNJWsww0zuPA7sW1B4EDlfVduBwd5skNwC7gRu7bR5Nclm3zWPAXmB7d1n4mJKkNbZk6FfVHwF/vaB8J3Cgu34AuKuv/mRVvVlVrwHHgR1JNgFXVtVzVVXAE33bSJLGZKUHcq+rqlMA3c9ru/pm4PW+dvNdbXN3fWF9oCR7k8wlmTtz5swKuyhJWmjUq3cGzdPXIvWBqmp/Vc1W1ezMzMzIOidNM8+tr2GsdPXOG0k2VdWpburmdFefB7b2tdsCnOzqWwbUNUUMHGn9W+lI/x
Cwp7u+B3iqr747yeVJrqd3wPaFbgrobJKd3aqde/u2kSSNyZIj/SS/CdwGXJNkHvh54GHgYJL7gK8BHwSoqqNJDgIvA+eAB6rqfPdQ99NbCXQF8Ex3kSSN0ZKhX1V3X+Ku916i/T5g34D6HHDTsnonacUuTLedePj9E+6J1hNPwyBJDTH0JakhnntHq+IUwvrjKiotxpG+RsKgkTYGQ1+SGmLoS1POb+qqn6EvSQ0x9CWpIYa+JDXEJZtaEeeIpY3Jkb4kNcTQl6SGGPqS1BBDX5IaYuhLUkNcvSM1on/FlSfIa5ehr2Vxqaa0sTm9I0kNMfQlqSGGviQ1xNCXGuTplttl6EtSQ1y9o0U5GpSmiyN9XZKBL00fQ1+Sc/wNMfQlqSGGviQ1ZFUHcpOcAM4C54FzVTWb5Grgt4BtwAng31TV33TtHwLu69r/+6r6vdU8v6TVcUqnPaMY6f9gVd1cVbPd7QeBw1W1HTjc3SbJDcBu4EZgF/BokstG8PySpCGtxfTOncCB7voB4K6++pNV9WZVvQYcB3aswfNLki5htaFfwP9KciTJ3q52XVWdAuh+XtvVNwOv920739XeIsneJHNJ5s6cObPKLkqSLljtl7NuraqTSa4Fnk3yZ4u0zYBaDWpYVfuB/QCzs7MD20iSlm9VoV9VJ7ufp5P8Nr3pmjeSbKqqU0k2Aae75vPA1r7NtwAnV/P8Whse3JOm14qnd5J8a5Jvu3Ad+FHgK8AhYE/XbA/wVHf9ELA7yeVJrge2Ay+s9PkljV7/B75f2JpOqxnpXwf8dpILj/MbVfU/k3wROJjkPuBrwAcBqupokoPAy8A54IGqOr+q3kuSlmXFoV9VXwXePaD+V8B7L7HNPmDfSp9TkrQ6fiNXkhriqZX1/zl/K/B9MO0MfflLLjXE0G+YYS+1xzl9SWqIoS9JDTH0JS3KL2lNF0Nfkhpi6DfGUZtWauF7x/fRxuTqnUb5C6uV8r2zsTnSl6SGGPqS1BBDX5Ia4px+I5yHlQSO9CWtgqvBNh5H+lPOX0hJ/Qz9KWXYSxrE6Z0pZOBr3Bab5nEKaH0x9CWNnEG/fhn6ktQQ5/Q3sAsjqRMPv3/CPZF6Fo7uHe2vP4b+BuQvlqSVMvQ3EMNd0mo5py9prDzIO1mGviQ1xOmddc4RkabFYu9lFyWMT6pq0n1Y1OzsbM3NzU26G2tu4ZvesFfLDP/VS3KkqmYX1sc+0k+yC/gvwGXAr1XVw+Puw3oxaHRj2EuDfw/8IBiNsY70k1wG/DnwI8A88EXg7qp6+VLbbLSR/mJB7iheGh0/BBZ3qZH+uEP/XwAfr6rbu9sPAVTVf7rUNmsd+sPMJS7WxgCXNo7+gde0f2isl9D/18Cuqvq33e17gO+vqg8taLcX2Nvd/G7glRF24xrgL0f4eBud++Ni7o+LuT8utpH2xz+tqpmFxXHP6WdA7S2fOlW1H9i/Jh1I5gZ9+rXK/XEx98fF3B8Xm4b9Me51+vPA1r7bW4CTY+6DJDVr3KH/RWB7kuuTfDOwGzg05j5IUrPGOr1TVeeSfAj4PXpLNj9dVUfH2QfWaNpoA3N/XMz9cTH3x8U2/P5Y91/OkiSNjufekaSGGPqS1JCpD/0kVyd5Nsmr3c+rFml7WZI/TfI74+zjOA2zP5JsTfL5JMeSHE3y4Un0dS0l2ZXklSTHkzw44P4keaS7/8tJ/vkk+jkuQ+yPH+/2w5eT/EmSd0+in+Oy1P7oa/d9Sc5330HaEKY+9IEHgcNVtR043N2+lA8Dx8bSq8kZZn+cAz5aVd8D7AQeSHLDGPu4prrTgfwq8D7gBuDuAa/vfcD27rIXeGysnRyjIffHa8APVNX3Av+RKTigeSlD7o8L7f4zvYUpG0YLoX8ncKC7fgC4a1CjJFuA9wO/NqZ+TcqS+6OqTlXVl7rrZ+l9EG4eWw/X3g7geFV9tar+DniS3n7pdyfwRPV8Afj2JJvG3dExWXJ/VNWfVNXfdDe/QO
87NtNqmPcHwL8DPgucHmfnVquF0L+uqk5BL8yAay/R7leA/wB8Y1wdm5Bh9wcASbYB7wGeX/Oejc9m4PW+2/O89UNtmDbTYrmv9T7gmTXt0WQtuT+SbAY+APy3MfZrJKbin6gk+X3gOwfc9bNDbv9jwOmqOpLktlH2bRJWuz/6Huft9EYyH6mqr4+ib+vEMKcDGeqUIVNi6Nea5Afphf6/XNMeTdYw++NXgI9V1flkUPP1aypCv6p++FL3JXkjyaaqOtX9eT7oT7FbgX+V5A7gW4Ark/x6Vf3EGnV5TY1gf5DkbfQC/zNV9bk16uqkDHM6kJZOGTLUa03yvfSmP99XVX81pr5NwjD7YxZ4sgv8a4A7kpyrqv8xni6uXAvTO4eAPd31PcBTCxtU1UNVtaWqttE7NcQfbNTAH8KS+yO9d/KngGNV9Ykx9m1chjkdyCHg3m4Vz07g/12YFptCS+6PJO8EPgfcU1V/PoE+jtOS+6Oqrq+qbV1m/HfgpzdC4EMbof8w8CNJXqX3z1seBkjyjiRPT7RnkzHM/rgVuAf4oSQvdpc7JtPd0auqc8CF04EcAw5W1dEkP5Xkp7pmTwNfBY4DnwR+eiKdHYMh98fPAd8BPNq9HzbOfzZapiH3x4blaRgkqSEtjPQlSR1DX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXkHwDB2OKlf3AgjAAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"fig, ax = plt.subplots()\n", | |
"ax.hist(samples.numpy().flatten(), bins='auto')\n", | |
"ax.set(title='HMC samples, all chains pooled', xlabel='sample value', ylabel='count');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "scratch3.6", | |
"language": "python", | |
"name": "scratch3_6" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for this gist and the post on your blog. They were invaluable for getting me thinking about vectorized code, and about how the shape of the initial state determines the number of chains sampled in TensorFlow. One thing I noticed: at least on my machine (CPU only), the code as-is executed on only one processor and took a while to run (28.9 s ± 5.74 s per loop, mean ± std. dev. of 7 runs, 1 loop each).
Running this instead:
runs in about 3.5 seconds. I know the intent of the post wasn't fast code, but I wanted to point this out for others who might stumble on this.