Skip to content

Instantly share code, notes, and snippets.

@Sayam753
Created July 1, 2020 19:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sayam753/cc5126279932cffd65064bdc44754c2a to your computer and use it in GitHub Desktop.
Save Sayam753/cc5126279932cffd65064bdc44754c2a to your computer and use it in GitHub Desktop.
Flattening and Full Rank ADVI in PyMC4
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import arviz as az\n",
"import collections\n",
"import itertools\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pymc4 as pm\n",
"import seaborn as sns\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"from typing import Dict\n",
"\n",
"from tensorflow_probability.python.internal import dtype_util\n",
"\n",
"tfd = tfp.distributions\n",
"tfb = tfp.bijectors\n",
"\n",
"plt.style.use('arviz-darkgrid')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic observations: 200 draws from a normal with mean 12 and sd 2.2.\n",
"data = np.random.normal(loc=12, scale=2.2, size=200)\n",
"\n",
"\n",
"@pm.model\n",
"def model():\n",
"    \"\"\"Normal likelihood for `data` with priors on its location and scale.\"\"\"\n",
"    mu = yield pm.Normal('mu', 0, 10)\n",
"    sigma = yield pm.Exponential('sigma', 1)\n",
"    # The observed node only has to be yielded; its result is never used.\n",
"    yield pm.Normal('ll', mu, sigma, observed=data)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model/mu', 'model/__log_sigma']\n"
]
}
],
"source": [
"# Call the decorated function; from here on `model` is rebound to the\n",
"# result of that call rather than to the function itself.\n",
"model = model()\n",
"\n",
"# Build the initial sampling state and show which unobserved variables the\n",
"# approximation has to cover.\n",
"state, deterministics_names = pm.inference.utils.initialize_sampling_state(\n",
"    model\n",
")\n",
"unobserved_keys = state.all_unobserved_values.keys()\n",
"print([*unobserved_keys])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Record for one variable inside the flat vector: its name, the slice it\n",
"# occupies, and the shape and dtype it had before flattening.\n",
"VarMap = collections.namedtuple('VarMap', 'var, slc, shp, dtyp')\n",
"\n",
"\n",
"class ArrayOrdering:\n",
"    \"\"\"\n",
"    Lay out a dict of named tensors as contiguous slices of one flat vector.\n",
"\n",
"    Variables are placed in the iteration order of ``free_rvs``. ``size`` is\n",
"    the total number of elements and ``by_name`` maps every variable name to\n",
"    its ``VarMap``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, free_rvs: Dict[str, tf.Tensor]):\n",
"        self.free_rvs = free_rvs\n",
"        self.by_name = {}\n",
"        self.size = 0\n",
"\n",
"        for name, tensor in free_rvs.items():\n",
"            # np.prod of an empty shape is 1, so a scalar takes one slot.\n",
"            flat_shape = int(np.prod(tensor.shape.as_list()))\n",
"            slc = slice(self.size, self.size + flat_shape)\n",
"            self.by_name[name] = VarMap(name, slc, tensor.shape, tensor.dtype)\n",
"            self.size += flat_shape\n",
"\n",
"    def flatten(self):\n",
"        \"\"\"Reshape every tensor in ``free_rvs`` to 1-D and concatenate them.\"\"\"\n",
"        raveled = [tf.reshape(var, shape=[-1]) for var in self.free_rvs.values()]\n",
"        return tf.concat(raveled, axis=0)\n",
"\n",
"    def split(self, flatten_tensor):\n",
"        \"\"\"\n",
"        Cut ``flatten_tensor`` back into a dict keyed by variable name.\n",
"\n",
"        Each entry is the variable's slice of ``flatten_tensor`` (taken along\n",
"        the first axis), reshaped to the recorded shape and cast to the\n",
"        recorded dtype.\n",
"        \"\"\"\n",
"        return {\n",
"            name: tf.cast(tf.reshape(flatten_tensor[slc], shape), dtype)\n",
"            for name, slc, shape, dtype in self.by_name.values()\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"@tf.function(autograph=False)\n",
"def logpfn(*values, **kwargs):\n",
"    \"\"\"\n",
"    Return the model's collected log-probability at the given state.\n",
"\n",
"    The state is passed either as a single positional flat vector, laid out\n",
"    as in ``ArrayOrdering(state.all_unobserved_values)``, or as keyword\n",
"    arguments mapping variable names to values -- never both.\n",
"\n",
"    Raises:\n",
"        TypeError: if positional and keyword states are mixed.\n",
"    \"\"\"\n",
"    if kwargs and values:\n",
"        raise TypeError(\"Either list state should be passed or a dict one\")\n",
"\n",
"    if kwargs:\n",
"        # Dict-style state is already keyed by variable name.\n",
"        val = kwargs\n",
"    else:\n",
"        val = ArrayOrdering(state.all_unobserved_values).split(values[0])\n",
"    _, st = pm.flow.evaluate_meta_model(model, values=val)\n",
"    return st.collect_log_prob()\n",
"\n",
"\n",
"def vectorize_logp_function(logpfn):\n",
"    \"\"\"Wrap ``logpfn`` so it is mapped over its inputs with ``tf.vectorized_map``.\"\"\"\n",
"\n",
"    def vectorized_logpfn(*q_samples):\n",
"        return tf.vectorized_map(lambda samples: logpfn(*samples), q_samples)\n",
"\n",
"    return vectorized_logpfn\n",
"\n",
"\n",
"# Target handed to the variational fit.\n",
"target_log_prob = vectorize_logp_function(logpfn)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def _build_posterior():\n",
"    \"\"\"\n",
"    Build a trainable full-rank Gaussian over the flattened unobserved state.\n",
"\n",
"    Returns a ``tfd.MultivariateNormalTriL`` whose ``loc`` is a ``tf.Variable``\n",
"    initialised from a standard normal and whose ``scale_tril`` is a\n",
"    ``tfp.util.TransformedVariable`` initialised to the identity matrix.\n",
"    \"\"\"\n",
"    order = ArrayOrdering(state.all_unobserved_values)\n",
"    flattened_shape = order.size\n",
"    dtype = dtype_util.common_dtype(state.all_unobserved_values.values(), dtype_hint=tf.float32)\n",
"    # Create both parameters in `dtype` so they agree with the casts below.\n",
"    loc = tf.Variable(tf.random.normal([flattened_shape], dtype=dtype), name=\"mu\")\n",
"    # Chain applies its bijectors from last to first.\n",
"    scale_tril_bijector = tfb.FillScaleTriL(\n",
"        diag_bijector=tfb.Chain(\n",
"            [\n",
"                tfb.Shift(tf.cast(1e-3, dtype)),  # diagonal offset\n",
"                tfb.Softplus(),\n",
"                tfb.Shift(tf.cast(np.log(np.expm1(1.0)), dtype)),  # initial scale\n",
"            ]\n",
"        ),\n",
"        diag_shift=None,\n",
"    )\n",
"\n",
"    # This is the lower-triangular scale factor, not the covariance itself.\n",
"    scale_tril = tfp.util.TransformedVariable(\n",
"        tf.eye(flattened_shape, dtype=dtype), scale_tril_bijector, name=\"sigma\"\n",
"    )\n",
"    return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)\n",
"\n",
"\n",
"approx = _build_posterior()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Optimiser instance used by run_approximation.\n",
"opt = pm.adam()\n",
"\n",
"\n",
"@tf.function(autograph=False)\n",
"def run_approximation():\n",
"    \"\"\"Fit `approx` against `target_log_prob` for 40000 steps and return the fit's result.\"\"\"\n",
"    return tfp.vi.fit_surrogate_posterior(\n",
"        target_log_prob_fn=target_log_prob,\n",
"        surrogate_posterior=approx,\n",
"        optimizer=opt,\n",
"        num_steps=40000,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow_probability/python/math/minimize.py:77: calling <lambda> (from tensorflow_probability.python.vi.optimization) with loss is deprecated and will be removed after 2020-07-01.\n",
"Instructions for updating:\n",
"The signature for `trace_fn`s passed to `minimize` has changed. Trace functions now take a single `traceable_quantities` argument, which is a `tfp.math.MinimizeTraceableQuantities` namedtuple containing `traceable_quantities.loss`, `traceable_quantities.gradients`, etc. Please update your `trace_fn` definition.\n"
]
}
],
"source": [
"mean_field = run_approximation()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=float32, numpy=array([12.167437 , 0.7683811], dtype=float32)>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"approx.mean()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(mean_field)\n",
"plt.yscale('log')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment