Last active
June 14, 2020 14:59
-
-
Save Sayam753/36bf35c482b705545eecb5353a8f8f6a to your computer and use it in GitHub Desktop.
Experimenting with ArviZ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pymc4 as pm4\n", | |
"import numpy as np\n", | |
"import arviz as az\n", | |
"import seaborn as sns\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_probability as tfp\n", | |
"import matplotlib.pyplot as plt\n", | |
"from pprint import pprint\n", | |
"\n", | |
"dtype = tf.float32\n", | |
"\n", | |
"# Seed NumPy (used for the synthetic data) and TensorFlow (presumably used by\n", | |
"# the PyMC4 variational fit and sampling) so Restart & Run All is reproducible.\n", | |
"SEED = 42\n", | |
"np.random.seed(SEED)\n", | |
"tf.random.set_seed(SEED)\n", | |
"\n", | |
"plt.style.use(\"arviz-darkgrid\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Target distribution: zero mean, unit variances, correlation 0.8.\n", | |
"mu = np.array([0.0, 0.0], dtype=np.float32)\n", | |
"cov = np.full((2, 2), 0.8, dtype=np.float32)\n", | |
"np.fill_diagonal(cov, 1.0)\n", | |
"\n", | |
"# Synthetic draws from the same distribution (not referenced by the model cell).\n", | |
"data = np.random.multivariate_normal(mean=mu, cov=cov, size=10000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@pm4.model\n", | |
"def model():\n", | |
"    \"\"\"Generator model with a single MvNormal variable named 'density'.\n", | |
"\n", | |
"    No observed data is attached; the mean and covariance are the\n", | |
"    module-level `mu` and `cov`.\n", | |
"    \"\"\"\n", | |
"    yield pm4.MvNormal(\"density\", loc=mu, covariance_matrix=cov)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Do not pass `graph_parents`. They will no longer be used.\n", | |
"WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow_probability/python/math/minimize.py:77: calling <lambda> (from tensorflow_probability.python.vi.optimization) with loss is deprecated and will be removed after 2020-07-01.\n", | |
"Instructions for updating:\n", | |
"The signature for `trace_fn`s passed to `minimize` has changed. Trace functions now take a single `traceable_quantities` argument, which is a `tfp.math.MinimizeTraceableQuantities` namedtuple containing `traceable_quantities.loss`, `traceable_quantities.gradients`, etc. Please update your `trace_fn` definition.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Instantiate the generator model, then fit a variational approximation to it.\n", | |
"# The deprecation warnings printed here originate inside TensorFlow /\n", | |
"# TensorFlow Probability site-packages, not in this notebook's code.\n", | |
"approx_model = model()\n", | |
"\n", | |
"mean_field = pm4.fit(approx_model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Number of draws taken from the fitted approximation. The returned arrays\n", | |
"# presumably put the draw axis first, with no chain axis.\n", | |
"N_DRAWS = 5000\n", | |
"posterior_samples = mean_field.approximation.sample(N_DRAWS)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# ArviZ expects every posterior array shaped (chain, draw, *event_shape).\n", | |
"# The VI draws have no chain axis, so passing them directly makes ArviZ read\n", | |
"# the draws as chains. Add a single leading chain axis to each variable.\n", | |
"posterior = {\n", | |
"    name: np.expand_dims(np.asarray(draws), axis=0)\n", | |
"    for name, draws in posterior_samples.items()\n", | |
"}\n", | |
"trace = az.from_dict(posterior=posterior)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Leave the InferenceData as the last expression so Jupyter renders it\n", | |
"# through its own repr instead of print()'s plain-text stream.\n", | |
"trace" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment