Created
November 4, 2016 03:48
-
-
Save kforeman/ad91b2db75166a57ee29e8808ea4c665 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Random effects model in TensorFlow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import statsmodels.api as sm\n", | |
"import statsmodels.formula.api as smf" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Simulate some data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Weights of piglets by day since birth, where $i$ indexes the pig and $t$ the day:\n",
"$$\n",
"y_{i,t} \\sim \\mathcal{N}\\big(\\hat{y}_{i,t}, \\sigma\\big)\n",
"$$\n",
"$$\n",
"\\hat{y}_{i,t} = \\alpha + \\beta t + \\pi_{i}\n",
"$$\n",
"$$\n",
"\\pi_{i} \\sim \\mathcal{N}\\big(0, \\tau\\big)\n",
"$$"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# True parameter values and sizes used to simulate the data\n",
"alpha = 2.0    # intercept\n",
"beta = -0.5    # slope on time\n",
"tau = 0.8      # scale passed to np.random.normal for the per-pig effect\n",
"sigma = 0.3    # scale passed to np.random.normal for the observation noise\n",
"N = 500        # number of pigs\n",
"T = 100        # number of time points per pig"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Seed the global NumPy RNG so Restart & Run All reproduces the same simulated data\n",
"np.random.seed(20161104)\n",
"\n",
"# Draw order matches the model terms: per-pig effect first, then per-observation noise\n",
"pig_effect = np.random.normal(0., tau, size=[N, 1])    # one draw per pig, broadcast across time\n",
"noise = np.random.normal(0., sigma, size=[N, T])       # one draw per (pig, time) observation\n",
"\n",
"# Rows are pigs and columns are time points; flatten() is row-major, so all T\n",
"# observations for pig 0 come first, then pig 1, and so on.\n",
"weight = (alpha + beta * np.arange(T) + pig_effect + noise).flatten()"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Long-format index columns: each pig id repeated T times, and the 0..T-1 time grid tiled N times\n",
"pig = np.arange(N).repeat(T)\n",
"time = np.tile(np.arange(T), reps=N)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>pig</th>\n", | |
" <th>time</th>\n", | |
" <th>weight</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>3.153095</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>2.264115</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>2.384547</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>1.689781</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0</td>\n", | |
" <td>4</td>\n", | |
" <td>0.968568</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" pig time weight\n", | |
"0 0 0 3.153095\n", | |
"1 0 1 2.264115\n", | |
"2 0 2 2.384547\n", | |
"3 0 3 1.689781\n", | |
"4 0 4 0.968568" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Assemble the simulated observations into one long-format frame and preview it\n",
"data = pd.DataFrame(data={'weight': weight, 'time': time, 'pig': pig})\n",
"data.head()"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Mixed effects model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/kfor/anaconda/lib/python3.5/site-packages/statsmodels/regression/mixed_linear_model.py:160: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future\n", | |
" self._params = np.zeros(self.k_tot)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Mixed Linear Model Regression Results\n", | |
"===========================================================\n", | |
"Model: MixedLM Dependent Variable: weight \n", | |
"No. Observations: 50000 Method: REML \n", | |
"No. Groups: 500 Scale: 0.0905 \n", | |
"Min. group size: 100 Likelihood: -12532.5919\n", | |
"Max. group size: 100 Converged: Yes \n", | |
"Mean group size: 100.0 \n", | |
"-----------------------------------------------------------\n", | |
" Coef. Std.Err. z P>|z| [0.025 0.975]\n", | |
"-----------------------------------------------------------\n", | |
"Intercept 2.042 0.036 56.010 0.000 1.970 2.113\n", | |
"time -0.500 0.000 -10731.501 0.000 -0.500 -0.500\n", | |
"Intercept RE 0.661 0.140 \n", | |
"===========================================================\n", | |
"\n", | |
"CPU times: user 7.85 s, sys: 143 ms, total: 7.99 s\n", | |
"Wall time: 4.61 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n",
"# Reference fit with statsmodels: weight regressed on time, grouped by pig\n",
"md = smf.mixedlm(formula=\"weight ~ time\", data=data, groups=data[\"pig\"])\n",
"mdf = md.fit()\n",
"print(mdf.summary())"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## TensorFlow version" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create variables for parameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Scalar parameters are stored as length-1 variables initialised at zero\n",
"tf_alpha = tf.Variable(tf.zeros(shape=[1]))\n",
"tf_beta = tf.Variable(tf.zeros(shape=[1]))\n",
"\n",
"# tf_tau is optimised on the log scale and exponentiated, which keeps it positive\n",
"tf_log_tau = tf.Variable(tf.zeros(shape=[1]))\n",
"tf_tau = tf.exp(tf_log_tau)\n",
"\n",
"# One effect per pig\n",
"tf_pi = tf.Variable(tf.zeros(shape=[N]))\n",
"\n",
"# Same log-scale treatment for tf_sigma\n",
"tf_log_sigma = tf.Variable(tf.zeros(shape=[1]))\n",
"tf_sigma = tf.exp(tf_log_sigma)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Convenience functions for likelihoods" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def normal_like(x, mu, tau):\n",
"    \"\"\"Summed normal log-density of `x` around `mu`, parameterised by precision `tau`.\n",
"\n",
"    Each element contributes 0.5 * (log(tau / (2*pi)) - tau * (x - mu)**2);\n",
"    the contributions are added up into a single value.\n",
"    \"\"\"\n",
"    sq_err = tf.square(x - mu)\n",
"    log_norm = tf.log(tau / np.pi / 2.)\n",
"    return tf.reduce_sum((log_norm - tau * sq_err) / 2.)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def gammaln(x):\n",
"    \"\"\"Approximate log-gamma of `x`, built from TensorFlow ops.\n",
"\n",
"    Fast approximation from Paul Mineiro:\n",
"    http://www.machinedlearnings.com/2011/06/faster-lda.html\n",
"    \"\"\"\n",
"    x_plus_3 = 3.0 + x\n",
"    log_prod = tf.log(x * (1.0 + x) * (2.0 + x))\n",
"    return -2.081061466 - x + 0.0833333 / x_plus_3 - log_prod + (2.5 + x) * tf.log(x_plus_3)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def gamma_like(x, a, b):\n",
"    \"\"\"Log-density of a Gamma distribution with shape `a` and rate `b`, evaluated at `x`.\n",
"\n",
"    log p(x) = a*log(b) - gammaln(a) + (a - 1)*log(x) - b*x\n",
"\n",
"    The density is computed from the arguments `a` and `b`. The previous body read the\n",
"    notebook-level simulation constants `alpha` and `beta` instead, so the values passed\n",
"    by the caller were ignored everywhere except the gammaln term.\n",
"\n",
"    `gammaln` is the approximate helper defined in this notebook.\n",
"    \"\"\"\n",
"    # a*log(b) and (a-1)*log(x) replace log(pow(., .)) so the powers are never formed explicitly\n",
"    return -1. * gammaln(a) + a * tf.log(b) - (b * x) + (a - 1.) * tf.log(x)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Generate predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Prediction per observation: slope * time plus the effect of that row's pig (looked up by index).\n",
"# There is no separate intercept term here; tf_alpha is used as the mean of tf_pi in the likelihood.\n",
"tf_pred = tf.add(tf_beta * time, tf.gather(tf_pi, pig))"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Generate negative log-likelihood" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Log-likelihood pieces (normal_like treats its third argument as a precision)\n",
"pi_like = normal_like(tf_pi, tf_alpha, tf_tau)        # pig effects centred on tf_alpha\n",
"data_like = normal_like(tf_pred, weight, tf_sigma)    # observed weights around the predictions\n",
"hyper_like = gamma_like(tf_tau, 1., 1.)               # prior term on tf_tau\n",
"\n",
"# The objective must include the prior term: hyper_like was computed but never used,\n",
"# and the sum previously added tf_tau itself in its place.\n",
"nll = -1. * (pi_like + data_like + hyper_like)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Setup optimization function" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Adagrad with learning rate 0.5; `train` is the op that performs one update step on nll\n",
"optimizer = tf.train.AdagradOptimizer(learning_rate=0.5)\n",
"train = optimizer.minimize(loss=nll)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Run optimization _(just iteratively for now, no stopping criteria)_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Open a session and run the initialiser so every tf.Variable holds its starting value.\n",
"# NOTE(review): tf.initialize_all_variables was superseded by tf.global_variables_initializer\n",
"# in later TensorFlow releases - confirm against the installed version before changing it.\n",
"sess = tf.Session()\n",
"sess.run(fetches=tf.initialize_all_variables())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 3.20795e+06 0.0 -0.5 0.778801\n", | |
"500 63241.3 1.32537 -0.499652 1.74351\n", | |
"1000 55567.5 2.02064 -0.500014 1.30848\n", | |
"1500 48318.2 2.04174 -0.500028 1.05106\n", | |
"2000 41542.0 2.04179 -0.500028 0.918082\n", | |
"2500 35329.8 2.0418 -0.500028 0.854782\n", | |
"3000 29783.8 2.0418 -0.500028 0.826603\n", | |
"3500 25000.6 2.0418 -0.500028 0.814751\n", | |
"4000 21045.3 2.0418 -0.500028 0.810078\n", | |
"4500 17930.8 2.0418 -0.500028 0.808426\n", | |
"5000 27262.9 2.06572 -0.500085 0.821607\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build the reporting op once, outside the loop, so that printing does not add new nodes\n",
"# to the graph on every report. tf_tau is passed to normal_like as a precision, so\n",
"# 1/sqrt(tf_tau) is on the standard-deviation scale.\n",
"tf_pi_sd = 1. / tf.sqrt(tf_tau)\n",
"\n",
"for step in range(5001):\n",
"    sess.run(train)\n",
"    if step % 500 == 0:\n",
"        # Fetch every reported value in one call, evaluated after this step's update\n",
"        nll_val, alpha_val, beta_val, sd_val = sess.run([nll, tf_alpha, tf_beta, tf_pi_sd])\n",
"        print(step, nll_val[0], alpha_val[0], beta_val[0], sd_val[0])"
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"gist_id": "96089814d2977532330e", | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment