Skip to content

Instantly share code, notes, and snippets.

@Sayam753
Created August 12, 2020 03:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sayam753/34e3ad014424ebd2902727114520a582 to your computer and use it in GitHub Desktop.
Save Sayam753/34e3ad014424ebd2902727114520a582 to your computer and use it in GitHub Desktop.
Progress bar using tf.print
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.status.idle": "2020-08-12T03:18:02.583136Z",
"shell.execute_reply": "2020-08-12T03:18:02.582234Z",
"shell.execute_reply.started": "2020-08-12T03:17:52.954002Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"from tensorflow_probability.python.mcmc.transformed_kernel import make_transformed_log_prob\n",
"\n",
"tfb = tfp.bijectors\n",
"tfd = tfp.distributions\n",
"dtype = tf.float32"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2020-08-12T03:18:02.586585Z",
"iopub.status.busy": "2020-08-12T03:18:02.585929Z",
"iopub.status.idle": "2020-08-12T03:18:02.599161Z",
"shell.execute_reply": "2020-08-12T03:18:02.594731Z",
"shell.execute_reply.started": "2020-08-12T03:18:02.586514Z"
}
},
"outputs": [],
"source": [
"mu = 12\n",
"sigma = 2.2\n",
"data = np.random.normal(mu, sigma, size=200)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2020-08-12T03:18:02.606526Z",
"iopub.status.busy": "2020-08-12T03:18:02.606019Z",
"iopub.status.idle": "2020-08-12T03:18:02.694395Z",
"shell.execute_reply": "2020-08-12T03:18:02.691961Z",
"shell.execute_reply.started": "2020-08-12T03:18:02.606457Z"
}
},
"outputs": [],
"source": [
"model = tfd.JointDistributionSequential([\n",
" tfd.Exponential(0.1, name='e'), # sigma: prior on the likelihood scale\n",
" tfd.Normal(loc=0, scale=10, name='n'), # mu: prior on the likelihood mean\n",
" lambda n, e: tfd.Normal(loc=n, scale=e)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2020-08-12T03:18:02.697303Z",
"iopub.status.busy": "2020-08-12T03:18:02.696615Z",
"iopub.status.idle": "2020-08-12T03:18:02.710323Z",
"shell.execute_reply": "2020-08-12T03:18:02.709011Z",
"shell.execute_reply.started": "2020-08-12T03:18:02.697232Z"
}
},
"outputs": [],
"source": [
"joint_log_prob = lambda *x: model.log_prob(x + (data,))\n",
"\n",
"unconstraining_bijectors = [\n",
" tfb.Exp(),\n",
" tfb.Identity()\n",
"]\n",
"\n",
"target_log_prob = make_transformed_log_prob(\n",
" joint_log_prob,\n",
" unconstraining_bijectors,\n",
" direction='forward',\n",
" enable_bijector_caching=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2020-08-12T03:18:02.716846Z",
"iopub.status.busy": "2020-08-12T03:18:02.716015Z",
"iopub.status.idle": "2020-08-12T03:18:02.763467Z",
"shell.execute_reply": "2020-08-12T03:18:02.762328Z",
"shell.execute_reply.started": "2020-08-12T03:18:02.716771Z"
}
},
"outputs": [],
"source": [
"parameters = model.sample(1)\n",
"parameters.pop()\n",
"dists = []\n",
"for i, parameter in enumerate(parameters):\n",
" shape = parameter[0].shape\n",
" loc = tf.Variable(\n",
" tf.random.normal(shape, dtype=dtype),\n",
" name='meanfield_%s_loc' % i,\n",
" dtype=dtype)\n",
" scale = tfp.util.TransformedVariable(\n",
" tf.fill(shape, value=tf.constant(0.02, dtype)),\n",
" tfb.Softplus(),\n",
" name='meanfield_%s_scale' % i,\n",
" )\n",
"\n",
" approx_parameter = tfd.Independent(tfd.Normal(loc=loc, scale=scale))\n",
" dists.append(approx_parameter)\n",
"\n",
"meanfield_advi = tfd.JointDistributionSequential(dists)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2020-08-12T03:18:02.766732Z",
"iopub.status.busy": "2020-08-12T03:18:02.765313Z",
"iopub.status.idle": "2020-08-12T03:18:14.500642Z",
"shell.execute_reply": "2020-08-12T03:18:14.499464Z",
"shell.execute_reply.started": "2020-08-12T03:18:02.766656Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|>>>>>>>>>>>>>>>>>>>>|"
]
}
],
"source": [
"num_steps = 10_000\n",
"num_cols = 20\n",
"it_break = num_steps // num_cols\n",
"\n",
"def trace_fn(traceable_quantities):\n",
" tf.cond(\n",
" tf.math.mod(traceable_quantities.step + 1, it_break) == 0,\n",
" lambda: tf.print(\n",
" tf.strings.reduce_join(\n",
" [\n",
" \"\\r|\",\n",
" tf.strings.reduce_join(\n",
" tf.repeat(\">\", (traceable_quantities.step + 1) // it_break, axis=0)\n",
" ),\n",
" tf.strings.reduce_join(\n",
" tf.repeat(\n",
" \".\",\n",
" num_cols - (traceable_quantities.step + 1) // it_break,\n",
" axis=0,\n",
" )\n",
" ),\n",
" \"|\",\n",
" ]\n",
" ),\n",
" end=\"\",\n",
" ),\n",
" lambda: tf.no_op(),\n",
" )\n",
" return traceable_quantities.loss\n",
"\n",
"opt = tf.optimizers.Adam(learning_rate=.5)\n",
"\n",
"def run_approximation():\n",
" loss_ = tfp.vi.fit_surrogate_posterior(\n",
" target_log_prob,\n",
" surrogate_posterior=meanfield_advi,\n",
" optimizer=opt,\n",
" num_steps=num_steps,\n",
" trace_fn=trace_fn\n",
" )\n",
" return loss_\n",
"\n",
"loss_ = run_approximation()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment