Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Experimenting with ArviZ
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pymc4 as pm4\n",
"import numpy as np\n",
"import arviz as az\n",
"import seaborn as sns\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"import matplotlib.pyplot as plt\n",
"from pprint import pprint\n",
"\n",
"dtype = tf.float32\n",
"plt.style.use(\"arviz-darkgrid\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"mu = np.zeros(2, dtype=np.float32)\n",
"cov = np.array([[1, 0.8], [0.8, 1]], dtype=np.float32)\n",
"data = np.random.multivariate_normal(mu, cov, size=10000)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"@pm4.model\n",
"def model():\n",
" density = yield pm4.MvNormal(\"density\", loc=mu, covariance_matrix=cov)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Do not pass `graph_parents`. They will no longer be used.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow_probability/python/math/minimize.py:77: calling <lambda> (from tensorflow_probability.python.vi.optimization) with loss is deprecated and will be removed after 2020-07-01.\n",
"Instructions for updating:\n",
"The signature for `trace_fn`s passed to `minimize` has changed. Trace functions now take a single `traceable_quantities` argument, which is a `tfp.math.MinimizeTraceableQuantities` namedtuple containing `traceable_quantities.loss`, `traceable_quantities.gradients`, etc. Please update your `trace_fn` definition.\n"
]
}
],
"source": [
"approx_model = model()\n",
"mean_field = pm4.fit(approx_model)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"posterior_samples = mean_field.approximation.sample(5000)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/site-packages/arviz/data/base.py:146: UserWarning: More chains (5000) than draws (2). Passed array should have shape (chains, draws, *shape)\n",
" UserWarning,\n"
]
}
],
"source": [
"trace = az.from_dict(posterior=posterior_samples)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inference data with groups:\n",
"\t> posterior\n"
]
}
],
"source": [
"print(trace)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment