Skip to content

Instantly share code, notes, and snippets.

@sharanry
Created June 27, 2018 01:23
Show Gist options
  • Save sharanry/f44e874074b6cc5edbd613762cfe1ea5 to your computer and use it in GitHub Desktop.
Save sharanry/f44e874074b6cc5edbd613762cfe1ea5 to your computer and use it in GitHub Desktop.
Eight Schools with PyMC4
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
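"# The Eight Schools Problem with PyMC4\n",
"\n",
"Fits a hierarchical model to the treatment effects (and their standard errors) observed in eight schools, partially pooling the per-school effects toward a shared mean. The model is written in non-centered form, sampled with PyMC4, and the posterior draws are summarized and plotted below."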
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native library is loaded,\n",
"# so it must be set *before* `import tensorflow` to have any effect.\n",
"# '0' is the most verbose level; raise it (e.g. to '2') to hide C++ INFO/WARNING logs.\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import tensorflow as tf\n",
"from tensorflow_probability import distributions as tfd\n",
"from tensorflow_probability import edward2 as ed\n",
"\n",
"import pymc4 as pm\n",
"from pymc4.inference.sampling.sample import sample"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x576 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Observed per-school treatment effects and their standard errors.\n",
"treatment_effects = np.array(\n",
"    [28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)  # treatment effects\n",
"treatment_stddevs = np.array(\n",
"    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)  # treatment SE\n",
"# Derived from the data instead of hardcoded, so the two can't drift apart.\n",
"num_schools = len(treatment_effects)  # number of schools\n",
"assert treatment_stddevs.shape == treatment_effects.shape, \"one SE per school expected\"\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"ax.bar(range(num_schools), treatment_effects, yerr=treatment_stddevs)\n",
"ax.set_title(\"{} Schools treatment effects\".format(num_schools))\n",
"ax.set_xlabel(\"School\")\n",
"ax.set_ylabel(\"Treatment effect\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Hierarchical model in non-centered form (variable names in parentheses):\n",
"#   mu             ~ Normal(0, 10)                       (avg_effect)\n",
"#   log(tau)       ~ Normal(5, 1)                        (avg_stddev)\n",
"#   theta_prime[j] ~ Normal(0, 1), j < num_schools       (school_effects_standard)\n",
"#   theta[j]       = mu + exp(log(tau)) * theta_prime[j] (school_effects)\n",
"#   y[j]           ~ Normal(theta[j], sigma[j])          (treatment_effects)\n",
"# where sigma = cfg.treatment_stddevs are the known per-school standard errors.\n",
"model = pm.Model(num_schools=num_schools, treatment_effects=treatment_effects, treatment_stddevs=treatment_stddevs)\n",
"\n",
"\n",
"@model.define\n",
"def process(cfg):\n",
"    avg_effect = ed.Normal(loc=0., scale=10., name=\"avg_effect\")  # mu\n",
"    avg_stddev = ed.Normal(loc=5., scale=1., name=\"avg_stddev\")  # log(tau)\n",
"    school_effects_standard = ed.Normal(\n",
"        loc=tf.zeros(cfg.num_schools),\n",
"        scale=tf.ones(cfg.num_schools),\n",
"        name=\"school_effects_standard\")  # theta_prime\n",
"    school_effects = avg_effect + tf.exp(avg_stddev) * school_effects_standard  # theta\n",
"    # The name matches the `treatment_effects` keyword given to pm.Model above;\n",
"    # presumably that is how this variable is bound to the observed data\n",
"    # (it is not listed in model.unobserved).\n",
"    treatment_effects = ed.Normal(\n",
"        loc=school_effects,\n",
"        scale=cfg.treatment_stddevs,\n",
"        name=\"treatment_effects\")  # y\n",
"    return treatment_effects"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'num_schools': 8,\n",
" 'treatment_effects': array([28., 8., -3., 7., -1., 1., 18., 12.], dtype=float32),\n",
" 'treatment_stddevs': array([15., 10., 16., 11., 9., 11., 10., 18.], dtype=float32)}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Values passed to pm.Model as keyword arguments\n",
"model.observed"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('avg_effect',\n",
" VariableDescription(Dist=<class 'tensorflow.python.ops.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'avg_effect' shape=() dtype=float32>)),\n",
" ('avg_stddev',\n",
" VariableDescription(Dist=<class 'tensorflow.python.ops.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'avg_stddev' shape=() dtype=float32>)),\n",
" ('school_effects_standard',\n",
" VariableDescription(Dist=<class 'tensorflow.python.ops.distributions.normal.Normal'>, shape=TensorShape([Dimension(8)]), rv=<ed.RandomVariable 'school_effects_standard' shape=(8,) dtype=float32>))])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Latent (unobserved) random variables defined in `process`\n",
"model.unobserved"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ -32.547695 34.432945 40.320427 28.81765 115.495995 62.690655\n",
" 26.68259 -112.8604 ]\n"
]
}
],
"source": [
"# Sanity check: call the model function once and print the random draw it returns.\n",
"# `_f` / `_cfg` are private pymc4 attributes (presumably the function registered\n",
"# via @model.define and the keyword config given to pm.Model).\n",
"with tf.Session() as sess:\n",
"    prior_draw = model._f(model._cfg)\n",
"    print(prior_draw.eval(session=sess))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Acceptance rate: 0.5994\n"
]
}
],
"source": [
"# Run the PyMC4 sampler; the returned trace maps each unobserved\n",
"# variable name to an array of draws (draws along axis 0).\n",
"trace = sample(model)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'avg_effect': array([-2.5910451, -2.5910451, -2.5910451, ..., 7.5722466, 5.421218 ,\n",
" 5.421218 ], dtype=float32),\n",
" 'avg_stddev': array([3.1497786, 3.1497786, 3.1497786, ..., 2.819227 , 2.3189976,\n",
" 2.3189976], dtype=float32),\n",
" 'school_effects_standard': array([[ 1.9479568 , 0.6442069 , -0.875897 , ..., 0.18446806,\n",
" 0.6115229 , 1.6996222 ],\n",
" [ 1.9479568 , 0.6442069 , -0.875897 , ..., 0.18446806,\n",
" 0.6115229 , 1.6996222 ],\n",
" [ 1.9479568 , 0.6442069 , -0.875897 , ..., 0.18446806,\n",
" 0.6115229 , 1.6996222 ],\n",
" ...,\n",
" [ 1.7085497 , -0.5002298 , 0.7994229 , ..., 0.534858 ,\n",
" 1.0336227 , 0.25752133],\n",
" [ 0.1164161 , 0.5631517 , 0.21413967, ..., -0.08510438,\n",
" 0.14131531, 0.37988228],\n",
" [ 0.1164161 , 0.5631517 , 0.21413967, ..., -0.08510438,\n",
" 0.14131531, 0.37988228]], dtype=float32)}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trace"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Rebuild per-school effects from the non-centered draws,\n",
"# theta = mu + exp(log_tau) * theta_prime, broadcasting each scalar draw\n",
"# across the school axis (one row per draw, one column per school).\n",
"school_effects_samples = (\n",
"    np.exp(trace['avg_stddev'])[:, None] * trace['school_effects_standard']\n",
"    + trace['avg_effect'][:, None]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Posterior means over all retained draws (axis 0 indexes draws).\n",
"print(\"E[avg_effect] = {}\".format(trace['avg_effect'].mean()))\n",
"print(\"E[avg_stddev] = {}\".format(trace['avg_stddev'].mean()))\n",
"print(\"E[school_effects_standard] =\")\n",
"print(trace['school_effects_standard'].mean(axis=0))\n",
"print(\"E[school_effects] =\")\n",
"print(school_effects_samples.mean(axis=0))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x720 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"warnings.filterwarnings('ignore')\n",
"# One row per school: trace of the draws on the left, their KDE on the right.\n",
"# squeeze=False keeps `axes` 2-D even if there is only a single school.\n",
"fig, axes = plt.subplots(num_schools, 2, sharex='col', sharey='col',\n",
"                         squeeze=False, figsize=(12, 10))\n",
"for i in range(num_schools):\n",
"    draws = school_effects_samples[:, i]\n",
"    axes[i, 0].plot(draws)\n",
"    axes[i, 0].set_title(\"School {} treatment effect chain\".format(i))\n",
"    # NOTE: seaborn >= 0.11 deprecates `shade` in favour of `fill=True`.\n",
"    sns.kdeplot(draws, ax=axes[i, 1], shade=True)\n",
"    axes[i, 1].set_title(\"School {} treatment effect distribution\".format(i))\n",
"axes[-1, 0].set_xlabel(\"Iteration\")\n",
"axes[-1, 1].set_xlabel(\"School effect\")\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Per-school 2.5th / 50th / 97.5th percentiles of the effect draws, i.e. the\n",
"# median and central 95% interval, computed column-wise in a single call.\n",
"school_effects_low, school_effects_med, school_effects_hi = np.percentile(\n",
"    school_effects_samples[:, :num_schools], [2.5, 50, 97.5], axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inferred posterior mean: 6.49\n",
"Inferred posterior mean se: 10.22\n"
]
}
],
"source": [
"# Both statistics pool every draw of every school into one sample: the first\n",
"# is the overall mean effect, and the reported 'se' is the standard deviation\n",
"# of those pooled draws (not a per-school standard error).\n",
"print(\"Inferred posterior mean: {0:.2f}\".format(np.mean(school_effects_samples)))\n",
"print(\"Inferred posterior mean se: {0:.2f}\".format(np.std(school_effects_samples)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (pymc3)",
"language": "python",
"name": "pymc3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment