@daob
Created March 4, 2016 15:53
An implementation of a finite mixture model with covariates ("latent class model") in TensorFlow
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf \n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
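"def weight_variable(shape):\n",
"    # Slopes start near zero: N(0, 0.1^2) draws, truncated at two standard deviations\n",
"    initial = tf.truncated_normal(shape, stddev=0.1)\n",
"    return tf.Variable(initial)\n",
"\n",
"def bias_variable(shape):\n",
"    # Intercepts start at a small positive constant\n",
"    initial = tf.constant(0.1, shape=shape)\n",
"    return tf.Variable(initial)"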
]
},
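{
"cell_type": "markdown",
"metadata": {},
"source": [
"`weight_variable` draws its starting values from `tf.truncated_normal`, which discards and re-draws any value more than two standard deviations from the mean. Below is a minimal plain-Python sketch of that rule, using only the standard library. `truncated_normal_init` is a hypothetical name, not a TensorFlow function.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def truncated_normal_init(n, stddev=0.1, seed=None):\n",
"    # Hypothetical helper: N(0, stddev^2) draws, re-drawing any value more than\n",
"    # 2 standard deviations from 0 (the rule tf.truncated_normal applies)\n",
"    rng = random.Random(seed)\n",
"    values = []\n",
"    while len(values) < n:\n",
"        v = rng.gauss(0.0, stddev)\n",
"        if abs(v) <= 2 * stddev:\n",
"            values.append(v)\n",
"    return values\n",
"\n",
"w0 = truncated_normal_init(1000, stddev=0.1, seed=1)\n",
"print(min(w0), max(w0))  # both lie within [-0.2, 0.2] by construction\n",
"```\n",
"\n",
"By contrast, `bias_variable` starts every intercept at the same constant, 0.1."
]
},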
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Latent class example on data generated in R"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"n <- 1000 #1e8\n",
"J <- 3\n",
"\n",
"set.seed(654)\n",
"\n",
"Z1 <- rbinom(n, size = 1, prob = 0.5)\n",
"Z2 <- rbinom(n, size = 1, prob = 0.5)\n",
"\n",
"X <- rbinom(n, size = 1, prob = plogis(1 + 0.3*Z1 - 99*Z2))\n",
"\n",
"Y <- matrix(rbinom(n*J, size = 1, prob = 0.3 + 0.4*X), n)\n",
"\n",
"dat <- data.frame(X = X, Y=Y, Z1 = Z1, Z2 = Z2)\n"
]
},
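{
"cell_type": "markdown",
"metadata": {},
"source": [
"The R code above generates a two-class model. The true class `X` depends on `Z1` and `Z2`, and each of the `J = 3` items has P(Y = 1 | X) = 0.3 + 0.4 X. Below is a rough plain-Python translation for readers without R, using the standard library only. `simulate` is a hypothetical name, and the sketch does not reproduce R's random numbers from `set.seed(654)`.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"def plogis(v):\n",
"    # Logistic CDF, like R's plogis()\n",
"    return 1.0 / (1.0 + math.exp(-v))\n",
"\n",
"def simulate(n=1000, J=3, seed=654):\n",
"    # Hypothetical translation of the R code; the seed only makes this sketch repeatable\n",
"    rng = random.Random(seed)\n",
"    rows = []\n",
"    for _ in range(n):\n",
"        z1 = int(rng.random() < 0.5)\n",
"        z2 = int(rng.random() < 0.5)\n",
"        # True class: logit P(X = 1) = 1 + 0.3 Z1 - 99 Z2\n",
"        x = int(rng.random() < plogis(1 + 0.3 * z1 - 99 * z2))\n",
"        # Items are independent given X: P(Y_j = 1 | X) = 0.3 + 0.4 X\n",
"        ys = [int(rng.random() < 0.3 + 0.4 * x) for _ in range(J)]\n",
"        rows.append({'X': x, 'Z1': z1, 'Z2': z2, 'Y': ys})\n",
"    return rows\n",
"\n",
"dat = simulate()\n",
"print(sum(r['X'] for r in dat) / float(len(dat)))  # share of X = 1 (expected value about 0.38)\n",
"```\n",
"\n",
"With Z2 = 1 the logit is around -98, so X = 1 practically never occurs in that half of the data."
]
},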
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LG5.0 output\n",
"\n",
"This shows the output obtained when estimating the correct model on these data in Latent GOLD 5.0.0.14161, in one go. Default settings are used, except that the Bayes constants are set to 0."
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"options\n",
" maxthreads=4;\n",
" algorithm \n",
" tolerance=1e-008 emtolerance=0.01 emiterations=550 nriterations=50 ;\n",
" startvalues\n",
" seed=0 sets=30 tolerance=1e-005 iterations=50;\n",
" bayes\n",
" categorical=0 variances=0 latent=0 poisson=0;\n",
" montecarlo\n",
" seed=0 sets=0 replicates=500 tolerance=1e-008;\n",
" quadrature nodes=10;\n",
" missing includeall;\n",
" output \n",
" parameters=first betaopts=wl standarderrors profile probmeans=posterior\n",
" frequencies bivariateresiduals classification estimatedvalues=regression\n",
" predictionstatistics iterationdetails;\n",
"variables\n",
" dependent Y.1, Y.2, Y.3;\n",
" independent Z1, Z2;\n",
" latent\n",
" Class nominal 2;\n",
"equations\n",
" Class <- 1 + Z1 + Z2;\n",
" Y.1 <- 1 + Class;\n",
" Y.2 <- 1 + Class;\n",
" Y.3 <- 1 + Class;\n"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"Regression Parameters\t\t\t\t\t\t\n",
"term\t\t\tcoef\tWald(0)\tdf\tp-value\n",
"Class(1)\t<-\t=\"1\"\t0.0000\t6.8382\t1\t0.0090\n",
"Class(2)\t<-\t=\"1\"\t-0.7978\t\t\t\n",
"Class(1)\t<-\tZ1\t0.0000\t1.3411\t1\t0.25\n",
"Class(2)\t<-\tZ1\t-0.3994\t\t\t\n",
"Class(1)\t<-\tZ2\t0.0000\t0.0010\t1\t0.97\n",
"Class(2)\t<-\tZ2\t31.6722\t\t\t\n",
"\t\t\t\t\t\t\n",
"Y.1(0)\t<-\t=\"1\"\t0.0000\t21.5333\t1\t3.5e-6\n",
"Y.1(1)\t<-\t=\"1\"\t0.7336\t\t\t\n",
"Y.1\t<-\tClass(1)\t0.0000\t65.0564\t1\t7.3e-16\n",
"Y.1\t<-\tClass(2)\t-1.5342\t\t\t\n",
"\t\t\t\t\t\t\n",
"Y.2(0)\t<-\t=\"1\"\t0.0000\t35.8329\t1\t2.2e-9\n",
"Y.2(1)\t<-\t=\"1\"\t1.1681\t\t\t\n",
"Y.2\t<-\tClass(1)\t0.0000\t87.7631\t1\t7.4e-21\n",
"Y.2\t<-\tClass(2)\t-2.0554\t\t\t\n",
"\t\t\t\t\t\t\n",
"Y.3(0)\t<-\t=\"1\"\t0.0000\t35.4458\t1\t2.6e-9\n",
"Y.3(1)\t<-\t=\"1\"\t0.9487\t\t\t\n",
"Y.3\t<-\tClass(1)\t0.0000\t75.5371\t1\t3.6e-18\n",
"Y.3\t<-\tClass(2)\t-1.6387\t\t\t\n",
"\n",
" \tClass\t \t \n",
" \t1\t2\tOverall\n",
"Size\t0.3722\t0.6278\t \n",
"Y.1\t \t\t\n",
"0\t0.3244\t0.6901\t0.5540\n",
"1\t0.6756\t0.3099\t0.4460\n",
"Mean\t0.6756\t0.3099\t0.4460\n",
"Y.2\t \t\t\n",
"0\t0.2372\t0.7083\t0.5330\n",
"1\t0.7628\t0.2917\t0.4670\n",
"Mean\t0.7628\t0.2917\t0.4670\n",
"Y.3\t \t\t\n",
"0\t0.2791\t0.6660\t0.5220\n",
"1\t0.7209\t0.3340\t0.4780\n",
"Mean\t0.7209\t0.3340\t0.4780\n",
"\n",
" \t \t \t \t \t \tClass\t \t \n",
"Z1\tZ2\tY.1\tY.2\tY.3\tObsFreq\tModal\t1\t2\n",
"0\t0\t0\t0\t0\t27.0000\t2\t0.1278\t0.8722\n",
"0\t0\t0\t0\t1\t33.0000\t2\t0.4300\t0.5700\n",
"0\t0\t0\t1\t0\t28.0000\t1\t0.5337\t0.4663\n",
"0\t0\t0\t1\t1\t33.0000\t1\t0.8549\t0.1451\n",
"0\t0\t1\t0\t0\t18.0000\t2\t0.4046\t0.5954\n",
"0\t0\t1\t0\t1\t35.0000\t1\t0.7777\t0.2223\n",
"0\t0\t1\t1\t0\t32.0000\t1\t0.8414\t0.1586\n",
"0\t0\t1\t1\t1\t72.0000\t1\t0.9647\t0.0353\n",
"0\t1\t0\t0\t0\t78.0000\t2\t0.0000\t1.0000\n",
"0\t1\t0\t0\t1\t40.0000\t2\t0.0000\t1.0000\n",
"0\t1\t0\t1\t0\t29.0000\t2\t0.0000\t1.0000\n",
"0\t1\t0\t1\t1\t13.0000\t2\t0.0000\t1.0000\n",
"0\t1\t1\t0\t0\t40.0000\t2\t0.0000\t1.0000\n",
"0\t1\t1\t0\t1\t18.0000\t2\t0.0000\t1.0000\n",
"0\t1\t1\t1\t0\t16.0000\t2\t0.0000\t1.0000\n",
"0\t1\t1\t1\t1\t6.0000\t2\t0.0000\t1.0000\n",
"1\t0\t0\t0\t0\t23.0000\t2\t0.1793\t0.8207\n",
"1\t0\t0\t0\t1\t18.0000\t1\t0.5293\t0.4707\n",
"1\t0\t0\t1\t0\t25.0000\t1\t0.6305\t0.3695\n",
"1\t0\t0\t1\t1\t34.0000\t1\t0.8978\t0.1022\n",
"1\t0\t1\t0\t0\t13.0000\t1\t0.5033\t0.4967\n",
"1\t0\t1\t0\t1\t19.0000\t1\t0.8391\t0.1609\n",
"1\t0\t1\t1\t0\t28.0000\t1\t0.8878\t0.1122\n",
"1\t0\t1\t1\t1\t75.0000\t1\t0.9760\t0.0240\n",
"1\t1\t0\t0\t0\t75.0000\t2\t0.0000\t1.0000\n",
"1\t1\t0\t0\t1\t39.0000\t2\t0.0000\t1.0000\n",
"1\t1\t0\t1\t0\t45.0000\t2\t0.0000\t1.0000\n",
"1\t1\t0\t1\t1\t14.0000\t2\t0.0000\t1.0000\n",
"1\t1\t1\t0\t0\t35.0000\t2\t0.0000\t1.0000\n",
"1\t1\t1\t0\t1\t22.0000\t2\t0.0000\t1.0000\n",
"1\t1\t1\t1\t0\t10.0000\t2\t0.0000\t1.0000\n",
"1\t1\t1\t1\t1\t7.0000\t2\t0.0000\t1.0000\n",
"\n"
]
},
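{
"cell_type": "markdown",
"metadata": {},
"source": [
"The profile table can be reproduced from the regression parameters. With `parameters=first`, the first category is the reference. The logit of P(Y.j = 1) is therefore the `Y.j(1)` intercept in class 1, and that intercept plus the `Class(2)` effect in class 2. Here is a small plain-Python check, with the coefficients copied from the output above.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid(v):\n",
"    return 1.0 / (1.0 + math.exp(-v))\n",
"\n",
"# (intercept of Y.j(1), effect of Class(2)) from the Latent GOLD output\n",
"coefs = {'Y.1': (0.7336, -1.5342), 'Y.2': (1.1681, -2.0554), 'Y.3': (0.9487, -1.6387)}\n",
"\n",
"profile = {}\n",
"for item, (intercept, class2) in sorted(coefs.items()):\n",
"    # (P(Y = 1 | Class 1), P(Y = 1 | Class 2))\n",
"    profile[item] = (sigmoid(intercept), sigmoid(intercept + class2))\n",
"    print('{} {:.4f} {:.4f}'.format(item, *profile[item]))\n",
"```\n",
"\n",
"The printed pairs should match the `Mean` rows of the profile table: 0.6756/0.3099, 0.7628/0.2917 and 0.7209/0.3340."
]
},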
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logistic regression of Y1 on Z1, Z2\n",
"\n",
"This is legacy code from when I first tried things out with a plain logistic regression. It worked, so I no longer use it."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#x = tf.placeholder(tf.float32, [None, 2])\n",
"##y_obs = tf.placeholder(tf.float32, shape=[None, 1])\n",
"\n",
"#a = bias_variable([1])\n",
"#b = weight_variable([2, 1])\n",
"\n",
"#y_pred = tf.nn.sigmoid(tf.matmul(x, b) + a)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Latent class model with two covariates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Performs streaming updates of an LCM. The approach is simply to specify the likelihood and let TensorFlow do the heavy lifting. Afterwards we can obtain the first and second derivatives, even with respect to parameters that are not updated in the model, although TensorFlow makes that more difficult than Theano does."
]
},
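{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out, the model defined below is as follows. Here σ is the logistic function, z_i holds the covariates (Z1, Z2) and y_ij the three items; the notation is chosen to mirror the code.\n",
"\n",
"$$\\eta_i = \\sigma(a + z_i^\\top b), \\qquad \\pi_{j1} = \\sigma(\\tau_j + \\lambda_j), \\qquad \\pi_{j2} = \\sigma(\\tau_j)$$\n",
"\n",
"$$P(y_i \\mid z_i) = \\eta_i \\prod_{j=1}^{3} \\pi_{j1}^{y_{ij}} (1 - \\pi_{j1})^{1 - y_{ij}} + (1 - \\eta_i) \\prod_{j=1}^{3} \\pi_{j2}^{y_{ij}} (1 - \\pi_{j2})^{1 - y_{ij}}$$\n",
"\n",
"$$-2\\log L = -2 \\sum_{i} \\log P(y_i \\mid z_i), \\qquad P(\\text{class } 1 \\mid y_i, z_i) = \\frac{\\eta_i \\prod_{j} \\pi_{j1}^{y_{ij}} (1 - \\pi_{j1})^{1 - y_{ij}}}{P(y_i \\mid z_i)}$$\n",
"\n",
"In the code, a, b, τ and λ are `a`, `b`, `tau` and `lam`. The last two quantities are `min2ll` and `eta_post`."
]
},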
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Model definition"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#tf.reset_default_graph()\n",
"\n",
"#sess.close()\n",
"sess = tf.Session()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Observed variables\n",
"x = tf.placeholder(tf.float32, [None, 2]) # Covariates\n",
"y_obs = tf.placeholder(tf.float32, shape=[None, 3]) # Dependent \"items\"\n",
"\n",
"# Parameters\n",
"a = bias_variable([1]) # LC logistic intercept\n",
"b = weight_variable([2, 1]) # LC logistic slopes wrt Z's\n",
"\n",
"tau = bias_variable([3]) # Item logistic intercepts\n",
"lam = weight_variable([1, 3]) # Item logistic slopes wrt LC\n",
"\n",
"# P(X | Z)\n",
"eta_pred = tf.nn.sigmoid(tf.matmul(x, b) + a)\n",
"\n",
"# Takes a predicted probability P(Y=1) and turns it into the likelihood\n",
"# of the observed response: p wherever Y=1 and 1-p wherever Y=0\n",
"def transform_pred1_to_lik(p):\n",
"    return (y_obs * p) + ((1 - y_obs) * (1 - p))\n",
"\n",
"# P(Y_j | X), for class 1 and class 2\n",
"# Could include a zero effect of Z1 here, not updated, to obtain its derivatives\n",
"y_pred1 = transform_pred1_to_lik(tf.nn.sigmoid(lam + tau))\n",
"y_pred2 = transform_pred1_to_lik(tf.nn.sigmoid(tau))\n",
"\n",
"# P(Y | X, Z) = P(Y | X)\n",
"# Multiplies the item likelihoods (conditional independence) to get the joint;\n",
"# exp(log(.) @ ones) is the row-wise product over the 3 items\n",
"ones = np.array([[1,],[1,],[1,],], dtype = np.float32)\n",
"y_pred1_joint = tf.exp(tf.matmul(tf.log(y_pred1), ones))\n",
"y_pred2_joint = tf.exp(tf.matmul(tf.log(y_pred2), ones))\n",
"\n",
"# P(Y | Z)\n",
"# Mixture model\n",
"y_joint = (eta_pred * y_pred1_joint) + ((1 - eta_pred) * y_pred2_joint)\n",
"\n",
"# P(X | Y, Z)\n",
"# Posterior for class 1\n",
"eta_post = (eta_pred * y_pred1_joint) / y_joint\n",
"\n",
" "
]
},
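{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check of the graph above, here is a plain-Python version of the same computation for a single response pattern. It works on the log scale and combines the two classes with log-sum-exp, a standard guard against underflow when there are many items. That is a swapped-in technique, not what the TensorFlow code does: the graph multiplies the probabilities directly, which is fine for three items. `mixture_loglik` and the other helpers are hypothetical names, and the parameter values are the rounded estimates printed further down in this notebook.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def log_sigmoid(v):\n",
"    # log(1 / (1 + exp(-v))), computed without overflow for large |v|\n",
"    return -math.log1p(math.exp(-v)) if v >= 0 else v - math.log1p(math.exp(v))\n",
"\n",
"def item_loglik(y, logits):\n",
"    # Sum over items of log p (where y_j = 1) or log(1 - p) (where y_j = 0),\n",
"    # using log(1 - sigmoid(v)) = log_sigmoid(-v)\n",
"    return sum(log_sigmoid(v) if yj == 1 else log_sigmoid(-v) for yj, v in zip(y, logits))\n",
"\n",
"def mixture_loglik(y, z, a, b, tau, lam):\n",
"    # Hypothetical helper: returns (log P(y | z), posterior P(class 1 | y, z)) for one row\n",
"    eta_logit = a + sum(bk * zk for bk, zk in zip(b, z))\n",
"    log_c1 = log_sigmoid(eta_logit) + item_loglik(y, [t + lm for t, lm in zip(tau, lam)])\n",
"    log_c2 = log_sigmoid(-eta_logit) + item_loglik(y, tau)\n",
"    m = max(log_c1, log_c2)\n",
"    log_joint = m + math.log(math.exp(log_c1 - m) + math.exp(log_c2 - m))  # log-sum-exp\n",
"    return log_joint, math.exp(log_c1 - log_joint)\n",
"\n",
"# Rounded estimates printed further down in this notebook\n",
"a_hat, b_hat = -0.9748, [-0.2913, 7.9589]\n",
"tau_hat, lam_hat = [0.8517, 0.8428, 0.8483], [-1.6970, -1.6944, -1.6916]\n",
"\n",
"ll, post = mixture_loglik([1, 1, 1], [0, 0], a_hat, b_hat, tau_hat, lam_hat)\n",
"print(-2 * ll, post)  # -2 log-likelihood of this row and its class-1 posterior\n",
"```\n",
"\n",
"For Z1 = Z2 = 0 and Y = (1, 1, 1), the class-1 posterior comes out small (about 0.03). That fits class 1 being the class with low item probabilities."
]
},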
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Objective and optimization definition"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"min2ll = -2*tf.reduce_sum(tf.log(y_joint)) # observations are independent, so their log-likelihoods add\n",
"#cross_entropy = -tf.reduce_sum(y_obs * tf.log(y_pred))\n",
"\n",
"#train_step = tf.train.AdamOptimizer().minimize(min2ll)\n",
"#train_step = tf.train.GradientDescentOptimizer(0.01).minimize(min2ll)\n",
"#train_step = tf.train.RMSPropOptimizer(.1).minimize(min2ll)\n",
"\n",
"global_step = tf.Variable(0, trainable=False)\n",
"\n",
"starter_learning_rate = 0.1\n",
"learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,\n",
" 100, 0.96, staircase=True)\n",
"train_step = tf.train.AdamOptimizer(learning_rate).minimize(min2ll, global_step=global_step)"
]
},
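{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `staircase=True`, `tf.train.exponential_decay` multiplies the starting rate by 0.96 once every 100 optimizer steps, i.e. lr = 0.1 · 0.96^floor(step / 100). Below is a quick plain-Python look at what that means for this run. `decayed_lr` is a hypothetical helper that reproduces the documented formula; it is not the TensorFlow function.\n",
"\n",
"```python\n",
"def decayed_lr(step, start=0.1, decay_steps=100, decay_rate=0.96):\n",
"    # Staircase decay: the exponent is the integer number of completed decay periods\n",
"    return start * decay_rate ** (step // decay_steps)\n",
"\n",
"for step in [0, 99, 100, 1000, 10000]:\n",
"    print(step, decayed_lr(step))\n",
"```\n",
"\n",
"`global_step` is incremented once per `train_step`, i.e. once per chunk. After the 10,000 chunks of a pass over a 10-million-row file, the rate is about 0.1 × 0.96^100 ≈ 0.0017, so Adam's late updates are much smaller than its early ones."
]
},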
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"init = tf.initialize_all_variables()\n",
"sess.run(init)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Summaries for TensorBoard. This breaks when re-running the Jupyter cells several times, but should work the first time."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Build the summary operation based on the TF collection of Summaries.\n",
"tf.train.write_graph(sess.graph_def, '/tmp/lca_logs','graph.pbtxt')\n",
"\n",
"tf.histogram_summary('a', a)\n",
"tf.histogram_summary('b', b)\n",
"tf.histogram_summary('tau', tau)\n",
"tf.histogram_summary('lam', lam)\n",
"tf.scalar_summary('-2*log-likelihood', min2ll)\n",
"\n",
"summary_op = tf.merge_all_summaries()\n",
"summary_writer = tf.train.SummaryWriter('/tmp/lca_logs',sess.graph_def)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Actually running the optimization"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"chunk_size = 1000\n",
"epochs = 1 # One epoch is enough for the 10-million-row dataset; the 1000-row dataset needs 10\n",
"\n",
"#boot_size = 10"
]
},
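{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each chunk is one optimizer step, so the number of Adam updates per epoch is the number of rows divided by `chunk_size`, rounded up. The sketch below works this out for the two data sets described next (10 million and 1000 rows). `steps_per_epoch` is a hypothetical helper.\n",
"\n",
"```python\n",
"def steps_per_epoch(n_rows, chunk_size=1000):\n",
"    # One train_step per chunk; the last chunk may be smaller than chunk_size\n",
"    return -(-n_rows // chunk_size)  # ceiling division\n",
"\n",
"print(steps_per_epoch(10 ** 7))    # dat_big: 10000 updates in a single epoch\n",
"print(steps_per_epoch(1000) * 10)  # dat: 10 epochs of one chunk each, 10 updates\n",
"```\n",
"\n",
"The 10,000 chunks agree with the final 1.0e+07 observations in the training trace below."
]
},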
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This uses the pandas option to read a file chunk by chunk, so I don't have to read everything into memory first and then select pieces of it. One could use `skiprows = test_size` to set aside part of the file as a validation set, or just use a separate file. `dat_big` has 10 million rows and `dat` has 1000."
]
},
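{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers without pandas, here is a plain-Python sketch of the same idea. It streams a CSV in fixed-size chunks and can skip the first `skip` data rows, which could be kept as a validation set. It uses the standard `csv` module in place of `pd.read_csv(..., chunksize=...)`, which yields an iterator of DataFrames. `iter_chunks` and its `skip` argument are assumptions for illustration, not the notebook's code.\n",
"\n",
"```python\n",
"import csv\n",
"from itertools import islice\n",
"\n",
"def iter_chunks(lines, chunk_size, skip=0):\n",
"    # Hypothetical helper: yields lists of at most chunk_size rows (as dicts).\n",
"    # `lines` can be an open file or any iterable of CSV lines; the first line is the header\n",
"    reader = csv.DictReader(lines)\n",
"    for _ in islice(reader, skip):\n",
"        pass  # skipped rows, e.g. reserved for validation\n",
"    while True:\n",
"        chunk = list(islice(reader, chunk_size))\n",
"        if not chunk:\n",
"            return\n",
"        yield chunk\n",
"\n",
"# Tiny in-memory stand-in for dat.csv, using the notebook's column names\n",
"lines = ['Z1,Z2,Y.1,Y.2,Y.3'] + ['0,1,0,0,1'] * 23\n",
"sizes = [len(c) for c in iter_chunks(lines, chunk_size=10, skip=5)]\n",
"print(sizes)  # 18 remaining rows: [10, 8]\n",
"```\n",
"\n",
"With a real file one would write `with open(path) as f:` and iterate over `iter_chunks(f, chunk_size)`. Like the pandas iterator, it is exhausted after one pass and has to be re-created for each epoch."
]
},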
{
"cell_type": "code",
"execution_count": 916,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Not possible with big data set\n",
"#feed_full = {x: df[['Z1', 'Z2',]], y_obs: df[['Y.1','Y.2','Y.3',]]}"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"    Epoch:      Obs:     -2LL:        a:     b[0]:     b[1]:\n",
"         0   0.0e+00   3770.68   -0.9515   -0.2136    4.8678\n",
"         0   5.0e+05   3788.94   -1.0114   -0.3555    6.0652\n",
"         0   1.0e+06   3814.81   -0.9851   -0.1853    6.4288\n",
"         0   1.5e+06   3830.98   -1.0326   -0.3319    6.6422\n",
"         0   2.0e+06   3779.44   -1.0616   -0.3019    6.7667\n",
"         0   2.5e+06   3825.46   -0.9909   -0.3216    6.9539\n",
"         0   3.0e+06   3843.16   -1.0418   -0.4212    7.1230\n",
"         0   3.5e+06   3866.26   -0.9937   -0.2220    7.3417\n",
"         0   4.0e+06   3834.61   -0.9681   -0.3630    7.3372\n",
"         0   4.5e+06   3796.56   -1.0323   -0.2355    7.4924\n",
"         0   5.0e+06   3857.98   -1.0719   -0.2346    7.3724\n",
"         0   5.5e+06   3777.30   -1.0074   -0.2907    7.4575\n",
"         0   6.0e+06   3819.84   -0.9815   -0.3449    7.5354\n",
"         0   6.5e+06   3848.10   -0.9803   -0.3837    7.7356\n",
"         0   7.0e+06   3781.37   -1.1016   -0.3132    7.7613\n",
"         0   7.5e+06   3841.09   -1.0411   -0.3044    7.8689\n",
"         0   8.0e+06   3841.88   -1.0126   -0.3186    7.8211\n",
"         0   8.5e+06   3797.86   -0.9804   -0.3154    7.8435\n",
"         0   9.0e+06   3851.94   -0.9967   -0.3104    7.8968\n",
"         0   9.5e+06   3822.74   -0.9823   -0.2963    7.9561\n",
"         0   1.0e+07   3818.26   -0.9748   -0.2913    7.9589\n"
]
}
],
"source": [
"def print_it(feed_chunk, epoch, n_chunks):\n",
"    # Print -2LL on the current chunk and the current covariate-model parameters\n",
"    LL = sess.run(min2ll, feed_dict=feed_chunk)\n",
"    a_val, b_val = sess.run([a, b])\n",
"    #g = sess.run(tf.gradients(min2ll, b), feed_dict=feed_chunk)[0]\n",
"    print(\"{:10d}{:10.1e}{:10.2f}{:10.4f}{:10.4f}{:10.4f}\".format(epoch, n_chunks*chunk_size, LL, a_val[0],\n",
"          b_val[0][0], b_val[1][0])) #, g[0][0], g[1][0]))\n",
"\n",
"\n",
"print(\"{:>10s}{:>10s}{:>10s}{:>10s}{:>10s}{:>10s}\".format(\"Epoch:\", \"Obs:\", \"-2LL:\", \"a:\", \"b[0]:\", \"b[1]:\"))\n",
"      #, \"dlL/db[0]\", \"dlL/db[1]\"))\n",
"\n",
"for j in range(epochs):\n",
"\n",
"    # Re-create the chunk iterator every epoch: the previous one is exhausted after a full pass\n",
"    df = pd.read_csv(\"/Users/daob/Downloads/tensorflow/dat_big.csv\", chunksize = chunk_size, iterator = True)\n",
"    i = 0\n",
"\n",
"    for chunk in df:\n",
"        #start, end = (i * batch_size, (i + 1) * batch_size)\n",
"        #xi = np.asarray(df[['Z1', 'Z2',]][start:end], dtype = \"float32\")\n",
"        #yi = np.asarray(df[['Y.1','Y.2','Y.3',]][start:end], dtype = \"float32\")\n",
"        #wi = np.array([np.random.poisson(1, batch_size) for i in range(boot_size)])\n",
"\n",
"        xi = np.asarray(chunk[['Z1', 'Z2',]], dtype = \"float32\")\n",
"        yi = np.asarray(chunk[['Y.1','Y.2','Y.3',]], dtype = \"float32\")\n",
"\n",
"        feed_chunk = {x: xi, y_obs: yi}\n",
"        sess.run(train_step, feed_dict = feed_chunk)\n",
"\n",
"        # TensorBoard summaries\n",
"        summary_str = sess.run(summary_op, feed_dict = feed_chunk)\n",
"        summary_writer.add_summary(summary_str, i)\n",
"\n",
"        if (i % 500 == 0):\n",
"            print_it(feed_chunk, j, i)\n",
"\n",
"        #if i >= 100:\n",
"        #    break # DEBUG\n",
"\n",
"        i += 1\n",
"\n",
"print_it(feed_chunk, j, i)\n"
]
},
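{
"cell_type": "markdown",
"metadata": {},
"source": [
"The commented-out `wi` line in the loop above hints at an online bootstrap. For each of `boot_size` replicates, every observation in a chunk gets an independent Poisson(1) weight, which approximates resampling with replacement without holding the full data set. Below is a standard-library sketch of those weights, with Knuth's Poisson sampler standing in for `np.random.poisson`. How the weights would enter the likelihood (here assumed to be a weighted sum of per-observation log-likelihoods) is an assumption; the notebook does not implement it.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"def poisson1(rng):\n",
"    # Knuth's method for a Poisson(1) draw: multiply uniforms until the product drops below exp(-1)\n",
"    limit, k, prod = math.exp(-1.0), 0, rng.random()\n",
"    while prod > limit:\n",
"        k += 1\n",
"        prod *= rng.random()\n",
"    return k\n",
"\n",
"def bootstrap_weights(chunk_size, boot_size, seed=0):\n",
"    # Hypothetical helper: one row of Poisson(1) weights per replicate (boot_size x chunk_size)\n",
"    rng = random.Random(seed)\n",
"    return [[poisson1(rng) for _ in range(chunk_size)] for _ in range(boot_size)]\n",
"\n",
"w = bootstrap_weights(chunk_size=1000, boot_size=10)\n",
"# Replicate r would minimise -2 * sum_i w[r][i] * loglik_i (loglik_i is hypothetical here)\n",
"print(sum(map(sum, w)) / 10000.0)  # average weight, close to 1\n",
"```"
]
},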
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Some output"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.9747526]\n",
"[[-0.29132286]\n",
" [ 7.95889664]]\n",
"[ 0.85168254 0.84284556 0.84828174]\n",
"[[-1.69697344 -1.69441473 -1.6916467 ]]\n"
]
}
],
"source": [
"print(sess.run(a))\n",
"print(sess.run(b))\n",
"print(sess.run(tau))\n",
"print(sess.run(lam))\n",
"\n",
"# {R} Y~X : -0.8516 1.7104 \n",
"# {R} X ~ Z1 + Z2: 1.021 0.294 -21.742 "
]
},
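{
"cell_type": "markdown",
"metadata": {},
"source": [
"The R estimates in the comments of the cell above use the true `X`, coded 1 for the class with high item probabilities. TensorFlow's class 1 is the low-probability class, i.e. `X = 0`, so the labels must be flipped before comparing. Flipping negates the class-membership logit. For the items, the class-1 intercept `tau + lam` becomes the `X = 0` intercept, and `-lam` becomes the effect of `X`. Below is a plain-Python sketch of that conversion, using the printed values.\n",
"\n",
"```python\n",
"# Estimates printed above (TensorFlow parameterization: class 1 = X = 0)\n",
"a_hat = -0.9747526\n",
"b_hat = [-0.29132286, 7.95889664]\n",
"tau_hat = [0.85168254, 0.84284556, 0.84828174]\n",
"lam_hat = [-1.69697344, -1.69441473, -1.6916467]\n",
"\n",
"# Logistic regression of X = 1 on Z1, Z2: every coefficient changes sign\n",
"x_coefs = [-a_hat] + [-bk for bk in b_hat]\n",
"# Item j: (intercept for X = 0, effect of X)\n",
"y_coefs = [(t + lm, -lm) for t, lm in zip(tau_hat, lam_hat)]\n",
"\n",
"print('X ~ Z1 + Z2: ' + ' '.join('{:.3f}'.format(c) for c in x_coefs))\n",
"for j, (intercept, slope) in enumerate(y_coefs, 1):\n",
"    print('Y.{} ~ X: {:.4f} {:.4f}'.format(j, intercept, slope))\n",
"```\n",
"\n",
"After relabelling, the intercept and Z1 effect (0.975, 0.291) and the item parameters (about -0.845 and 1.69 to 1.70) are close to the R values (1.021, 0.294; -0.8516, 1.7104). The Z2 effect is much smaller in magnitude (-7.96 against -21.74)."
]
},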
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
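"These parameter estimates give the right answers, up to label switching: TensorFlow's class 1 corresponds to X = 0 in the generating model, so the signs of `a`, `b` and `lam` are flipped relative to the R estimates in the comments above. The exception is the Z2 effect: the generating value is 99 (after the sign flip), R's `glm` gives 21.7 and Latent GOLD 31.7, while we get 8. This is probably not a stabilization of logit coefficients by TensorFlow. With an effect this large, P(class 1 | Z2 = 1) is practically 1 for any of these values, so the likelihood is nearly flat in that direction, and `b[1]` was still drifting upward when training stopped (from 4.87 to 7.96 in the trace above).\n"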
]
},
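{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flat likelihood in the Z2 direction is easy to see numerically. With the estimated `a` and `b[0]`, the class-1 probability for Z2 = 1 is already above 0.998 at the estimated `b[1]` of about 8. Larger values only push it closer to 1. The plain-Python sketch below compares three values of `b[1]`: the TensorFlow estimate, the Latent GOLD estimate (31.67) and the generating value (99), all taken from earlier in the notebook.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid(v):\n",
"    return 1.0 / (1.0 + math.exp(-v))\n",
"\n",
"a_hat, b0_hat = -0.9748, -0.2913  # estimated intercept and Z1 effect\n",
"for b1 in [7.9589, 31.6722, 99.0]:\n",
"    # P(class 1 | Z1, Z2 = 1) for Z1 = 0 and Z1 = 1\n",
"    p = [sigmoid(a_hat + b0_hat * z1 + b1) for z1 in (0, 1)]\n",
"    print(b1, ['{:.6f}'.format(v) for v in p])\n",
"```\n",
"\n",
"Going from 8 to 99 changes these probabilities by less than 0.0015. The data therefore say little about the size of the effect beyond 'very large', and where an optimizer stops in that direction probably depends mostly on the number of steps and the learning rate."
]
},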
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.98837703]\n",
" [ 0.10787231]\n",
" [ 0.99946773]\n",
" [ 0.99784988]\n",
" [ 0.99946773]\n",
" [ 0.10787231]\n",
" [ 0.99946916]\n",
" [ 0.39688635]\n",
" [ 0.1395804 ]\n",
" [ 0.46762258]\n",
" [ 0.99990243]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.99946916]\n",
" [ 0.99784988]\n",
" [ 0.39627409]\n",
" [ 0.99946916]\n",
" [ 0.10760621]\n",
" [ 0.9996022 ]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.99946773]\n",
" [ 0.98837703]\n",
" [ 0.10760621]\n",
" [ 0.9999271 ]\n",
" [ 0.10811879]\n",
" [ 0.13927341]\n",
" [ 0.9999271 ]\n",
" [ 0.99947059]\n",
" [ 0.10760621]\n",
" [ 0.9999271 ]\n",
" [ 0.99960315]\n",
" [ 0.9996022 ]\n",
" [ 0.10760621]\n",
" [ 0.13927341]\n",
" [ 0.13927341]\n",
" [ 0.9999271 ]\n",
" [ 0.1389419 ]\n",
" [ 0.10760621]\n",
" [ 0.99946916]\n",
" [ 0.8273958 ]\n",
" [ 0.99710965]\n",
" [ 0.02173034]\n",
" [ 0.99960315]\n",
" [ 0.78176117]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.02173034]\n",
" [ 0.99947059]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.99946773]\n",
" [ 0.39754918]\n",
" [ 0.99946916]\n",
" [ 0.99990243]\n",
" [ 0.8273958 ]\n",
" [ 0.39754918]\n",
" [ 0.99946773]\n",
" [ 0.10811879]\n",
" [ 0.99947059]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.02173034]\n",
" [ 0.99947059]\n",
" [ 0.4682596 ]\n",
" [ 0.02886732]\n",
" [ 0.99710965]\n",
" [ 0.99990243]\n",
" [ 0.78176117]\n",
" [ 0.99946773]\n",
" [ 0.99946916]\n",
" [ 0.4682596 ]\n",
" [ 0.39754918]\n",
" [ 0.99946916]\n",
" [ 0.02886732]\n",
" [ 0.39754918]\n",
" [ 0.1389419 ]\n",
" [ 0.02173034]\n",
" [ 0.99947059]\n",
" [ 0.99946773]\n",
" [ 0.99783838]\n",
" [ 0.99946773]\n",
" [ 0.1395804 ]\n",
" [ 0.99711758]\n",
" [ 0.10787231]\n",
" [ 0.39754918]\n",
" [ 0.99990243]\n",
" [ 0.10760621]\n",
" [ 0.02173034]\n",
" [ 0.1395804 ]\n",
" [ 0.4682596 ]\n",
" [ 0.39754918]\n",
" [ 0.1389419 ]\n",
" [ 0.99990243]\n",
" [ 0.99960428]\n",
" [ 0.1389419 ]\n",
" [ 0.99990243]\n",
" [ 0.02886732]\n",
" [ 0.99784988]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99711758]\n",
" [ 0.99947059]\n",
" [ 0.99946916]\n",
" [ 0.99990243]\n",
" [ 0.99947059]\n",
" [ 0.1389419 ]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.4682596 ]\n",
" [ 0.02173034]\n",
" [ 0.99947059]\n",
" [ 0.10760621]\n",
" [ 0.46894887]\n",
" [ 0.99784988]\n",
" [ 0.99784988]\n",
" [ 0.02173034]\n",
" [ 0.99990243]\n",
" [ 0.99946773]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.99960315]\n",
" [ 0.99960428]\n",
" [ 0.10760621]\n",
" [ 0.99960428]\n",
" [ 0.46894887]\n",
" [ 0.10811879]\n",
" [ 0.46894887]\n",
" [ 0.02886732]\n",
" [ 0.1389419 ]\n",
" [ 0.10760621]\n",
" [ 0.9996022 ]\n",
" [ 0.99946916]\n",
" [ 0.39754918]\n",
" [ 0.13927341]\n",
" [ 0.99960428]\n",
" [ 0.46762258]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.10787231]\n",
" [ 0.9999271 ]\n",
" [ 0.99784988]\n",
" [ 0.99990243]\n",
" [ 0.99946773]\n",
" [ 0.99784988]\n",
" [ 0.99783838]\n",
" [ 0.02173034]\n",
" [ 0.13927341]\n",
" [ 0.10811879]\n",
" [ 0.1389419 ]\n",
" [ 0.8273958 ]\n",
" [ 0.99990243]\n",
" [ 0.39627409]\n",
" [ 0.99990243]\n",
" [ 0.99783838]\n",
" [ 0.10760621]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.39688635]\n",
" [ 0.39688635]\n",
" [ 0.78176117]\n",
" [ 0.78176117]\n",
" [ 0.4682596 ]\n",
" [ 0.1389419 ]\n",
" [ 0.39688635]\n",
" [ 0.99960428]\n",
" [ 0.99960428]\n",
" [ 0.1389419 ]\n",
" [ 0.10787231]\n",
" [ 0.9999271 ]\n",
" [ 0.9996022 ]\n",
" [ 0.1389419 ]\n",
" [ 0.02173034]\n",
" [ 0.99947059]\n",
" [ 0.99946773]\n",
" [ 0.10760621]\n",
" [ 0.99990243]\n",
" [ 0.39688635]\n",
" [ 0.1395804 ]\n",
" [ 0.1389419 ]\n",
" [ 0.99784428]\n",
" [ 0.10760621]\n",
" [ 0.39754918]\n",
" [ 0.99946773]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.99960428]\n",
" [ 0.02886732]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.10760621]\n",
" [ 0.9999271 ]\n",
" [ 0.99783838]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.99946773]\n",
" [ 0.10760621]\n",
" [ 0.99960315]\n",
" [ 0.10760621]\n",
" [ 0.98837703]\n",
" [ 0.46762258]\n",
" [ 0.10811879]\n",
" [ 0.99990243]\n",
" [ 0.99946916]\n",
" [ 0.02886732]\n",
" [ 0.99784428]\n",
" [ 0.99711758]\n",
" [ 0.10787231]\n",
" [ 0.13927341]\n",
" [ 0.46894887]\n",
" [ 0.99712497]\n",
" [ 0.9999271 ]\n",
" [ 0.39627409]\n",
" [ 0.10787231]\n",
" [ 0.99946773]\n",
" [ 0.99960315]\n",
" [ 0.99946916]\n",
" [ 0.10811879]\n",
" [ 0.99783838]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.1389419 ]\n",
" [ 0.02886732]\n",
" [ 0.99990243]\n",
" [ 0.78176117]\n",
" [ 0.46894887]\n",
" [ 0.99990243]\n",
" [ 0.46762258]\n",
" [ 0.99960315]\n",
" [ 0.4682596 ]\n",
" [ 0.02173034]\n",
" [ 0.99960428]\n",
" [ 0.39688635]\n",
" [ 0.99784428]\n",
" [ 0.02886732]\n",
" [ 0.02886732]\n",
" [ 0.99990243]\n",
" [ 0.10811879]\n",
" [ 0.46894887]\n",
" [ 0.46762258]\n",
" [ 0.9996022 ]\n",
" [ 0.4682596 ]\n",
" [ 0.10811879]\n",
" [ 0.13927341]\n",
" [ 0.99946916]\n",
" [ 0.1389419 ]\n",
" [ 0.39688635]\n",
" [ 0.99990243]\n",
" [ 0.99784428]\n",
" [ 0.10811879]\n",
" [ 0.1389419 ]\n",
" [ 0.99710965]\n",
" [ 0.99783838]\n",
" [ 0.02173034]\n",
" [ 0.10811879]\n",
" [ 0.9999271 ]\n",
" [ 0.10787231]\n",
" [ 0.99947059]\n",
" [ 0.10811879]\n",
" [ 0.99990243]\n",
" [ 0.4682596 ]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.10787231]\n",
" [ 0.10787231]\n",
" [ 0.02886732]\n",
" [ 0.46894887]\n",
" [ 0.1389419 ]\n",
" [ 0.10787231]\n",
" [ 0.13927341]\n",
" [ 0.99960315]\n",
" [ 0.02173034]\n",
" [ 0.99784428]\n",
" [ 0.1395804 ]\n",
" [ 0.8273958 ]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.99946773]\n",
" [ 0.02173034]\n",
" [ 0.8273958 ]\n",
" [ 0.46894887]\n",
" [ 0.10760621]\n",
" [ 0.9996022 ]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.10760621]\n",
" [ 0.9999271 ]\n",
" [ 0.99711758]\n",
" [ 0.99946773]\n",
" [ 0.99711758]\n",
" [ 0.99947059]\n",
" [ 0.99712497]\n",
" [ 0.99946773]\n",
" [ 0.78176117]\n",
" [ 0.13927341]\n",
" [ 0.1389419 ]\n",
" [ 0.39627409]\n",
" [ 0.39754918]\n",
" [ 0.02886732]\n",
" [ 0.1389419 ]\n",
" [ 0.9996022 ]\n",
" [ 0.99960428]\n",
" [ 0.10760621]\n",
" [ 0.13927341]\n",
" [ 0.39688635]\n",
" [ 0.02886732]\n",
" [ 0.99960428]\n",
" [ 0.9999271 ]\n",
" [ 0.39754918]\n",
" [ 0.13927341]\n",
" [ 0.99947059]\n",
" [ 0.02173034]\n",
" [ 0.9999271 ]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.10811879]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.1389419 ]\n",
" [ 0.1389419 ]\n",
" [ 0.99960315]\n",
" [ 0.9996022 ]\n",
" [ 0.99946773]\n",
" [ 0.99712497]\n",
" [ 0.8273958 ]\n",
" [ 0.99960428]\n",
" [ 0.39688635]\n",
" [ 0.1389419 ]\n",
" [ 0.10787231]\n",
" [ 0.99947059]\n",
" [ 0.8273958 ]\n",
" [ 0.02173034]\n",
" [ 0.99960315]\n",
" [ 0.99960428]\n",
" [ 0.1389419 ]\n",
" [ 0.4682596 ]\n",
" [ 0.99960428]\n",
" [ 0.8273958 ]\n",
" [ 0.10760621]\n",
" [ 0.1395804 ]\n",
" [ 0.99946916]\n",
" [ 0.02886732]\n",
" [ 0.99946773]\n",
" [ 0.99990243]\n",
" [ 0.10787231]\n",
" [ 0.78176117]\n",
" [ 0.9999271 ]\n",
" [ 0.8273958 ]\n",
" [ 0.99990243]\n",
" [ 0.8273958 ]\n",
" [ 0.10760621]\n",
" [ 0.78176117]\n",
" [ 0.99960315]\n",
" [ 0.99946916]\n",
" [ 0.39627409]\n",
" [ 0.1395804 ]\n",
" [ 0.8273958 ]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.99946916]\n",
" [ 0.10787231]\n",
" [ 0.99946916]\n",
" [ 0.99946916]\n",
" [ 0.10811879]\n",
" [ 0.10787231]\n",
" [ 0.99710965]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.99711758]\n",
" [ 0.10787231]\n",
" [ 0.9999271 ]\n",
" [ 0.46894887]\n",
" [ 0.1389419 ]\n",
" [ 0.13927341]\n",
" [ 0.99960428]\n",
" [ 0.02173034]\n",
" [ 0.99946916]\n",
" [ 0.10760621]\n",
" [ 0.39688635]\n",
" [ 0.46762258]\n",
" [ 0.99990243]\n",
" [ 0.10760621]\n",
" [ 0.1395804 ]\n",
" [ 0.99990243]\n",
" [ 0.39754918]\n",
" [ 0.13927341]\n",
" [ 0.9996022 ]\n",
" [ 0.10811879]\n",
" [ 0.99990243]\n",
" [ 0.46762258]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.99946916]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.9996022 ]\n",
" [ 0.02886732]\n",
" [ 0.99946916]\n",
" [ 0.10760621]\n",
" [ 0.4682596 ]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.8273958 ]\n",
" [ 0.10811879]\n",
" [ 0.99960428]\n",
" [ 0.9999271 ]\n",
" [ 0.4682596 ]\n",
" [ 0.13927341]\n",
" [ 0.13927341]\n",
" [ 0.02173034]\n",
" [ 0.1389419 ]\n",
" [ 0.02886732]\n",
" [ 0.99711758]\n",
" [ 0.9999271 ]\n",
" [ 0.99960315]\n",
" [ 0.99960315]\n",
" [ 0.98837703]\n",
" [ 0.99784988]\n",
" [ 0.99711758]\n",
" [ 0.10760621]\n",
" [ 0.39688635]\n",
" [ 0.99960428]\n",
" [ 0.10811879]\n",
" [ 0.9999271 ]\n",
" [ 0.10787231]\n",
" [ 0.78176117]\n",
" [ 0.02173034]\n",
" [ 0.02173034]\n",
" [ 0.10760621]\n",
" [ 0.02173034]\n",
" [ 0.99960428]\n",
" [ 0.99946773]\n",
" [ 0.13927341]\n",
" [ 0.10787231]\n",
" [ 0.39754918]\n",
" [ 0.99947059]\n",
" [ 0.39688635]\n",
" [ 0.78176117]\n",
" [ 0.02173034]\n",
" [ 0.1389419 ]\n",
" [ 0.02886732]\n",
" [ 0.99711758]\n",
" [ 0.99990243]\n",
" [ 0.99960315]\n",
" [ 0.99960315]\n",
" [ 0.46894887]\n",
" [ 0.02173034]\n",
" [ 0.99711758]\n",
" [ 0.78176117]\n",
" [ 0.10787231]\n",
" [ 0.39688635]\n",
" [ 0.10811879]\n",
" [ 0.99710965]\n",
" [ 0.1389419 ]\n",
" [ 0.99990243]\n",
" [ 0.10811879]\n",
" [ 0.9996022 ]\n",
" [ 0.46894887]\n",
" [ 0.99946916]\n",
" [ 0.39627409]\n",
" [ 0.99947059]\n",
" [ 0.1395804 ]\n",
" [ 0.10811879]\n",
" [ 0.1389419 ]\n",
" [ 0.46894887]\n",
" [ 0.9996022 ]\n",
" [ 0.99947059]\n",
" [ 0.9996022 ]\n",
" [ 0.99946916]\n",
" [ 0.4682596 ]\n",
" [ 0.99710965]\n",
" [ 0.99960428]\n",
" [ 0.39627409]\n",
" [ 0.02886732]\n",
" [ 0.99712497]\n",
" [ 0.39754918]\n",
" [ 0.39627409]\n",
" [ 0.02886732]\n",
" [ 0.1389419 ]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.46762258]\n",
" [ 0.46894887]\n",
" [ 0.99960315]\n",
" [ 0.99710965]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.99947059]\n",
" [ 0.39754918]\n",
" [ 0.10787231]\n",
" [ 0.4682596 ]\n",
" [ 0.99783838]\n",
" [ 0.02886732]\n",
" [ 0.39754918]\n",
" [ 0.10760621]\n",
" [ 0.99990243]\n",
" [ 0.99960315]\n",
" [ 0.98450828]\n",
" [ 0.46762258]\n",
" [ 0.99990243]\n",
" [ 0.10787231]\n",
" [ 0.46894887]\n",
" [ 0.9999271 ]\n",
" [ 0.99960428]\n",
" [ 0.99947059]\n",
" [ 0.10760621]\n",
" [ 0.99990243]\n",
" [ 0.39754918]\n",
" [ 0.99712497]\n",
" [ 0.99990243]\n",
" [ 0.1395804 ]\n",
" [ 0.10760621]\n",
" [ 0.39688635]\n",
" [ 0.9999271 ]\n",
" [ 0.02886732]\n",
" [ 0.02886732]\n",
" [ 0.10760621]\n",
" [ 0.99710965]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.8273958 ]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.46894887]\n",
" [ 0.99784428]\n",
" [ 0.46762258]\n",
" [ 0.9996022 ]\n",
" [ 0.99960428]\n",
" [ 0.99960428]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.99946916]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99960428]\n",
" [ 0.99784428]\n",
" [ 0.10760621]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.99784988]\n",
" [ 0.39627409]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.99712497]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.4682596 ]\n",
" [ 0.8273958 ]\n",
" [ 0.46762258]\n",
" [ 0.1395804 ]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.99711758]\n",
" [ 0.99990243]\n",
" [ 0.99960428]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.98450828]\n",
" [ 0.10760621]\n",
" [ 0.99711758]\n",
" [ 0.46762258]\n",
" [ 0.98837703]\n",
" [ 0.9996022 ]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.39754918]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.1395804 ]\n",
" [ 0.02173034]\n",
" [ 0.99960428]\n",
" [ 0.10787231]\n",
" [ 0.99960315]\n",
" [ 0.02886732]\n",
" [ 0.99960315]\n",
" [ 0.02886732]\n",
" [ 0.99784988]\n",
" [ 0.99711758]\n",
" [ 0.9999271 ]\n",
" [ 0.39688635]\n",
" [ 0.99960428]\n",
" [ 0.99990243]\n",
" [ 0.99960428]\n",
" [ 0.4682596 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99947059]\n",
" [ 0.4682596 ]\n",
" [ 0.9999271 ]\n",
" [ 0.1389419 ]\n",
" [ 0.46762258]\n",
" [ 0.99784428]\n",
" [ 0.9999271 ]\n",
" [ 0.46894887]\n",
" [ 0.99947059]\n",
" [ 0.99946916]\n",
" [ 0.99784428]\n",
" [ 0.9999271 ]\n",
" [ 0.39754918]\n",
" [ 0.99784428]\n",
" [ 0.99960315]\n",
" [ 0.99960315]\n",
" [ 0.02886732]\n",
" [ 0.39627409]\n",
" [ 0.46894887]\n",
" [ 0.02886732]\n",
" [ 0.99960428]\n",
" [ 0.1389419 ]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.10787231]\n",
" [ 0.10760621]\n",
" [ 0.4682596 ]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.02173034]\n",
" [ 0.10811879]\n",
" [ 0.46894887]\n",
" [ 0.13927341]\n",
" [ 0.02886732]\n",
" [ 0.10760621]\n",
" [ 0.98450828]\n",
" [ 0.10811879]\n",
" [ 0.99946773]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.9999271 ]\n",
" [ 0.8273958 ]\n",
" [ 0.99946916]\n",
" [ 0.9996022 ]\n",
" [ 0.99990243]\n",
" [ 0.46762258]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.99784428]\n",
" [ 0.99946916]\n",
" [ 0.02173034]\n",
" [ 0.99946773]\n",
" [ 0.10787231]\n",
" [ 0.9999271 ]\n",
" [ 0.99712497]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.99990243]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.02173034]\n",
" [ 0.1395804 ]\n",
" [ 0.10760621]\n",
" [ 0.46894887]\n",
" [ 0.46894887]\n",
" [ 0.02886732]\n",
" [ 0.10787231]\n",
" [ 0.1395804 ]\n",
" [ 0.99946773]\n",
" [ 0.39627409]\n",
" [ 0.02886732]\n",
" [ 0.99990243]\n",
" [ 0.02886732]\n",
" [ 0.99784988]\n",
" [ 0.99947059]\n",
" [ 0.13927341]\n",
" [ 0.99946916]\n",
" [ 0.99712497]\n",
" [ 0.1389419 ]\n",
" [ 0.99946916]\n",
" [ 0.99990243]\n",
" [ 0.99947059]\n",
" [ 0.9999271 ]\n",
" [ 0.99946773]\n",
" [ 0.9996022 ]\n",
" [ 0.99960315]\n",
" [ 0.8273958 ]\n",
" [ 0.99990243]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.99784428]\n",
" [ 0.13927341]\n",
" [ 0.99960315]\n",
" [ 0.10811879]\n",
" [ 0.9999271 ]\n",
" [ 0.39754918]\n",
" [ 0.99990243]\n",
" [ 0.02886732]\n",
" [ 0.99712497]\n",
" [ 0.9999271 ]\n",
" [ 0.10760621]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.39688635]\n",
" [ 0.99947059]\n",
" [ 0.02173034]\n",
" [ 0.9999271 ]\n",
" [ 0.99946916]\n",
" [ 0.99946773]\n",
" [ 0.39754918]\n",
" [ 0.99946916]\n",
" [ 0.39627409]\n",
" [ 0.99960315]\n",
" [ 0.99960428]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.1389419 ]\n",
" [ 0.10811879]\n",
" [ 0.9999271 ]\n",
" [ 0.10787231]\n",
" [ 0.1389419 ]\n",
" [ 0.1395804 ]\n",
" [ 0.9996022 ]\n",
" [ 0.99947059]\n",
" [ 0.10811879]\n",
" [ 0.99960428]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.9996022 ]\n",
" [ 0.39688635]\n",
" [ 0.10787231]\n",
" [ 0.13927341]\n",
" [ 0.99783838]\n",
" [ 0.99947059]\n",
" [ 0.10811879]\n",
" [ 0.02886732]\n",
" [ 0.99946916]\n",
" [ 0.02886732]\n",
" [ 0.98837703]\n",
" [ 0.02886732]\n",
" [ 0.1389419 ]\n",
" [ 0.99960428]\n",
" [ 0.10811879]\n",
" [ 0.1389419 ]\n",
" [ 0.99990243]\n",
" [ 0.99960315]\n",
" [ 0.02886732]\n",
" [ 0.99946916]\n",
" [ 0.39627409]\n",
" [ 0.99710965]\n",
" [ 0.02173034]\n",
" [ 0.9996022 ]\n",
" [ 0.9999271 ]\n",
" [ 0.1395804 ]\n",
" [ 0.99960428]\n",
" [ 0.99947059]\n",
" [ 0.10811879]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.9999271 ]\n",
" [ 0.46894887]\n",
" [ 0.13927341]\n",
" [ 0.9996022 ]\n",
" [ 0.02173034]\n",
" [ 0.13927341]\n",
" [ 0.02886732]\n",
" [ 0.98837703]\n",
" [ 0.02173034]\n",
" [ 0.4682596 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.39627409]\n",
" [ 0.99712497]\n",
" [ 0.02886732]\n",
" [ 0.99712497]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.99947059]\n",
" [ 0.99710965]\n",
" [ 0.39627409]\n",
" [ 0.9999271 ]\n",
" [ 0.99783838]\n",
" [ 0.4682596 ]\n",
" [ 0.4682596 ]\n",
" [ 0.99946773]\n",
" [ 0.13927341]\n",
" [ 0.39688635]\n",
" [ 0.99784988]\n",
" [ 0.1395804 ]\n",
" [ 0.02173034]\n",
" [ 0.39627409]\n",
" [ 0.9996022 ]\n",
" [ 0.9999271 ]\n",
" [ 0.13927341]\n",
" [ 0.99947059]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.99947059]\n",
" [ 0.10787231]\n",
" [ 0.1389419 ]\n",
" [ 0.10811879]\n",
" [ 0.02886732]\n",
" [ 0.78176117]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.78176117]\n",
" [ 0.1395804 ]\n",
" [ 0.78176117]\n",
" [ 0.39627409]\n",
" [ 0.39688635]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.02173034]\n",
" [ 0.99960315]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.02173034]\n",
" [ 0.02173034]\n",
" [ 0.99711758]\n",
" [ 0.9996022 ]\n",
" [ 0.39688635]\n",
" [ 0.9996022 ]\n",
" [ 0.1389419 ]\n",
" [ 0.1395804 ]\n",
" [ 0.99990243]\n",
" [ 0.1389419 ]\n",
" [ 0.1395804 ]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.02173034]\n",
" [ 0.46762258]\n",
" [ 0.4682596 ]\n",
" [ 0.46762258]\n",
" [ 0.1389419 ]\n",
" [ 0.9999271 ]\n",
" [ 0.78176117]\n",
" [ 0.39754918]\n",
" [ 0.39754918]\n",
" [ 0.46762258]\n",
" [ 0.99990243]\n",
" [ 0.98450828]\n",
" [ 0.39754918]\n",
" [ 0.1389419 ]\n",
" [ 0.02173034]\n",
" [ 0.10811879]\n",
" [ 0.99947059]\n",
" [ 0.99960428]\n",
" [ 0.9996022 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99946773]\n",
" [ 0.02173034]\n",
" [ 0.39688635]\n",
" [ 0.99712497]\n",
" [ 0.99990243]\n",
" [ 0.99783838]\n",
" [ 0.99990243]\n",
" [ 0.99712497]\n",
" [ 0.99990243]\n",
" [ 0.10760621]\n",
" [ 0.02886732]\n",
" [ 0.46894887]\n",
" [ 0.9999271 ]\n",
" [ 0.46762258]\n",
" [ 0.10811879]\n",
" [ 0.13927341]\n",
" [ 0.10760621]\n",
" [ 0.13927341]\n",
" [ 0.02173034]\n",
" [ 0.9996022 ]\n",
" [ 0.9999271 ]\n",
" [ 0.99711758]\n",
" [ 0.99784988]\n",
" [ 0.9999271 ]\n",
" [ 0.9996022 ]\n",
" [ 0.99946773]\n",
" [ 0.99946773]\n",
" [ 0.4682596 ]\n",
" [ 0.99946773]\n",
" [ 0.99946916]\n",
" [ 0.39754918]\n",
" [ 0.9996022 ]\n",
" [ 0.46762258]\n",
" [ 0.9999271 ]\n",
" [ 0.99711758]\n",
" [ 0.39688635]\n",
" [ 0.99990243]\n",
" [ 0.13927341]\n",
" [ 0.99990243]\n",
" [ 0.02173034]\n",
" [ 0.99990243]\n",
" [ 0.9996022 ]\n",
" [ 0.02173034]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.10760621]\n",
" [ 0.78176117]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.46762258]\n",
" [ 0.99784428]\n",
" [ 0.99960428]\n",
" [ 0.99990243]\n",
" [ 0.99784988]\n",
" [ 0.99960315]\n",
" [ 0.13927341]\n",
" [ 0.13927341]\n",
" [ 0.99960428]\n",
" [ 0.99711758]\n",
" [ 0.99946773]\n",
" [ 0.02886732]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.9999271 ]\n",
" [ 0.9999271 ]\n",
" [ 0.46894887]\n",
" [ 0.9996022 ]\n",
" [ 0.9996022 ]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.99990243]\n",
" [ 0.10760621]\n",
" [ 0.02173034]\n",
" [ 0.99946916]\n",
" [ 0.9999271 ]\n",
" [ 0.10760621]\n",
" [ 0.99960428]\n",
" [ 0.46762258]\n",
" [ 0.4682596 ]\n",
" [ 0.78176117]\n",
" [ 0.99960428]\n",
" [ 0.46762258]\n",
" [ 0.9999271 ]\n",
" [ 0.13927341]\n",
" [ 0.98837703]\n",
" [ 0.99784428]\n",
" [ 0.02173034]\n",
" [ 0.10760621]\n",
" [ 0.99947059]\n",
" [ 0.99783838]\n",
" [ 0.99784428]\n",
" [ 0.10760621]\n",
" [ 0.46762258]\n",
" [ 0.9996022 ]\n",
" [ 0.02173034]\n",
" [ 0.39754918]\n",
" [ 0.99711758]\n",
" [ 0.99946916]\n",
" [ 0.02886732]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.99784988]\n",
" [ 0.10760621]\n",
" [ 0.78176117]\n",
" [ 0.99946773]\n",
" [ 0.39688635]\n",
" [ 0.10811879]\n",
" [ 0.98837703]\n",
" [ 0.8273958 ]\n",
" [ 0.99990243]\n",
" [ 0.99990243]\n",
" [ 0.8273958 ]\n",
" [ 0.39688635]\n",
" [ 0.46762258]\n",
" [ 0.02173034]\n",
" [ 0.02886732]\n",
" [ 0.02173034]\n",
" [ 0.99990243]\n",
" [ 0.99712497]\n",
" [ 0.39627409]\n",
" [ 0.99960428]\n",
" [ 0.99946773]\n",
" [ 0.9999271 ]\n",
" [ 0.39754918]\n",
" [ 0.1389419 ]\n",
" [ 0.9996022 ]\n",
" [ 0.99946916]\n",
" [ 0.99960428]\n",
" [ 0.8273958 ]\n",
" [ 0.99712497]\n",
" [ 0.10811879]\n",
" [ 0.1395804 ]\n",
" [ 0.99947059]\n",
" [ 0.99946773]\n",
" [ 0.46894887]\n",
" [ 0.99947059]\n",
" [ 0.9996022 ]\n",
" [ 0.99784428]\n",
" [ 0.9996022 ]\n",
" [ 0.10760621]\n",
" [ 0.39754918]\n",
" [ 0.99960315]\n",
" [ 0.10787231]\n",
" [ 0.46894887]\n",
" [ 0.9999271 ]]\n"
]
}
],
"source": [
"print sess.run(eta_post, feed_dict = feed_chunk)\n",
"\n",
"# On the 1000-observation dataset, the first observation's posterior should be about 0.85 and the second's about 0.10.\n",
"# The last observation's should be about 0.1451, the second-to-last's about 0.1022, the third-to-last's about 0.22, and the 5th's 1.000.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### EPCs\n",
"\n",
"I thought it would be really easy to get derivatives and Hessians for all parameters, including new ones. It's not, mostly because TensorFlow (unlike Theano!) makes it difficult to obtain these for a set of parameters jointly: you have to calculate every off-diagonal element of the Hessian yourself and assemble the results by hand. A pain, but possible. So here is a proof of concept that calculates an (approximate) EPC (expected parameter change) for direct effects of Z1 and Z2 on the items. It is still a little strange, because in the hypothetical alternative model each Z has its own delta, but that delta is shared by all Y items. \n",
"\n",
"It might be best to calculate these on the test set, since they are a kind of model-fit criterion.\n"
]
},
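{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of the quantity being approximated in the cells that follow, assuming the usual definition of the EPC as a single Newton step on the fit function. Write $F(\\delta)$ for `min2ll` viewed as a function of `delta`, with all other parameters held at their estimates, and let $g = \\partial F / \\partial \\delta$ and $H = \\partial^2 F / \\partial \\delta \\, \\partial \\delta^\\top$, both evaluated at $\\delta = 0$. A second-order Taylor expansion gives\n",
"\n",
"$$F(\\delta) \\approx F(0) + g^\\top \\delta + \\tfrac{1}{2} \\, \\delta^\\top H \\delta,$$\n",
"\n",
"whose minimizer is\n",
"\n",
"$$\\widehat{\\delta}_{\\mathrm{EPC}} = -H^{-1} g.$$\n",
"\n",
"Since $F = -2\\ell$, the factor $-2$ cancels between $H^{-1}$ and $g$, so the step is the same whether it is computed from $-2\\ell$ or from $\\ell$. Using only the $\\delta$ block of the Hessian, i.e. ignoring cross-derivatives between `delta` and the other parameters, is one reason the result is only approximate."
]
},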
{
"cell_type": "code",
"execution_count": 697,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"delta = tf.Variable(np.asarray([[0.],[0.]], dtype = np.float32)) # Item-bias logistic slopes with respect to Z1 and Z2\n",
"\n",
"# New but equivalent version of the model: it includes the item bias delta, which is held at zero:\n",
"y_pred1 = transform_pred1_to_lik(tf.nn.sigmoid(lam + tau + tf.matmul(x, delta))) \n",
"y_pred2 = transform_pred1_to_lik(tf.nn.sigmoid(tau + tf.matmul(x, delta)))\n",
"\n",
"# P(Y | X, Z), which reduces to P(Y | X) while delta is zero\n",
"# Takes the prediction for each item and applies the conditional independence rule to yield the joint\n",
"ones = np.array([[1,],[1,],[1,],], dtype = np.float32)\n",
"y_pred1_joint = tf.exp(tf.matmul(tf.log(y_pred1), ones))\n",
"y_pred2_joint = tf.exp(tf.matmul(tf.log(y_pred2), ones))\n",
"\n",
"# P(Y | Z)\n",
"# Mixture model\n",
"y_joint = (eta_pred * y_pred1_joint) + ((1 - eta_pred) * y_pred2_joint)\n",
"\n",
"# P(X | Y, Z)\n",
"# Posterior for class 1\n",
"eta_post = (eta_pred * y_pred1_joint) / y_joint\n",
"\n",
"min2ll = -2*tf.reduce_sum(tf.log(y_joint)) # assumes independent observations\n",
"\n",
"# Only initialize the new parameter\n",
"init_new = tf.initialize_variables([delta,])\n",
"sess.run(init_new)\n"
]
},
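{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph above combines item-level likelihoods under conditional independence, mixes the two classes, and applies Bayes' rule. A NumPy re-creation on made-up numbers may make the `exp(log(.) @ ones)` trick easier to follow: summing logs against a column of ones and exponentiating is just the product over items. All values here are hypothetical (two respondents, three items), and the item likelihoods are assumed to already be $P(Y_j = y_j \\mid X)$, which is presumably what `transform_pred1_to_lik` returns.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical P(Y_j = y_j | X = class) for 2 respondents (rows) and 3 items (columns).\n",
"lik1 = np.array([[0.7, 0.7, 0.3],\n",
"                 [0.3, 0.3, 0.3]])  # class 1\n",
"lik2 = np.array([[0.3, 0.3, 0.7],\n",
"                 [0.7, 0.7, 0.7]])  # class 2\n",
"eta = np.array([[0.6], [0.6]])       # hypothetical P(X = 1 | Z) for each respondent\n",
"\n",
"ones = np.ones((3, 1))\n",
"# Conditional independence: joint = product over items = exp(sum of log-likelihoods).\n",
"joint1 = np.exp(np.log(lik1).dot(ones))\n",
"joint2 = np.exp(np.log(lik2).dot(ones))\n",
"\n",
"y_joint = eta * joint1 + (1 - eta) * joint2  # mixture: P(Y | Z)\n",
"eta_post = eta * joint1 / y_joint             # Bayes' rule: P(X = 1 | Y, Z)\n",
"min2ll = -2 * np.log(y_joint).sum()          # -2 log-likelihood, independent observations\n",
"print(eta_post.ravel())\n",
"```\n",
"\n",
"For the first respondent the answer pattern fits class 1 better, so the posterior rises from the prior 0.6 to 7/9; for the second it falls to about 0.11."
]
},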
{
"cell_type": "code",
"execution_count": 724,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-1.0839479 0.51302063 1.17497635 1.11338735 -1.55259323 -1.98827338\n",
" -1.55796206 -0.39906496 6.30396605]\n",
"[-1.0839479 0.51302063 1.17497635 1.11338735 -1.55259323 -1.98827338\n",
" -1.55796206 -0.39906496 6.30396605 0. 0. ]\n"
]
}
],
"source": [
"theta = tf.concat(0,[a, tau, tf.reshape(lam, [-1]), tf.reshape(b, [-1])])\n",
"theta_aug = tf.concat(0, [theta, tf.reshape(delta, [-1])])\n",
"\n",
"print sess.run(theta)\n",
"print sess.run(theta_aug)"
]
},
{
"cell_type": "code",
"execution_count": 761,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3861.33\n",
"g_delta:\n",
"[[ 23.57310677]\n",
" [ 34.95653152]]\n",
"H_delta:\n",
"[[ 841.83892822]\n",
" [ 932.23809814]]\n",
"d delta0 / d delta:\n",
"313.186\n"
]
}
],
"source": [
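"# Check that -2lL is still the same\n",
"# These don't work: theta is assembled *from* the variables, so min2ll does not depend on it\n",
"# and tf.gradients has no path to differentiate along. Not sure how to get around this:\n",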
"#grad_theta = tf.gradients(min2ll, theta)\n",
"#grad_theta_aug = tf.gradients(min2ll, theta_aug)\n",
"\n",
"print sess.run(min2ll, feed_dict=feed_full) \n",
"\n",
"# Unfortunately TensorFlow (unlike Theano!) does not support proper Jacobians/Hessians, \n",
"# so this needs to be done by hand :*(\n",
"g_a = tf.gradients(min2ll, a)\n",
"g_b = tf.gradients(min2ll, b)\n",
"g_tau = tf.gradients(min2ll, tau)\n",
"g_lam = tf.gradients(min2ll, lam)\n",
"g_delta = tf.gradients(min2ll, delta)\n",
"\n",
"g = [g_a, g_b, g_tau, g_lam, g_delta]\n",
"#print [sess.run(g_vec, feed_dict=feed_full) for g_vec in g]\n",
"\n",
"print \"g_delta:\"\n",
"g_delta_val = sess.run(g_delta, feed_dict=feed_full)[0]\n",
"print g_delta_val\n",
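"print \"H_delta:\"\n",
"# Caution: tf.gradients differentiates the sum of its ys, so this returns the row sums\n",
"# H[0,0] + H[0,1] and H[1,0] + H[1,1] of the Hessian, not its diagonal.\n",
"H_delta_val = sess.run(tf.gradients(g_delta, delta), feed_dict=feed_full)[0]\n",
"print H_delta_val\n",
"print \"d delta0 / d delta:\"\n",
"# Derivative of g_delta[0] with respect to delta; element [1][0] is the off-diagonal H[0,1].\n",
"H_delta_val_01 = sess.run(tf.gradients(tf.slice(g_delta[0], [0,0], [1,1]), delta), feed_dict=feed_full)[0][1][0]\n",
"print H_delta_val_01\n",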
"\n"
]
},
{
"cell_type": "code",
"execution_count": 771,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.01605898]\n",
" [ 0.0321024 ]]\n"
]
}
],
"source": [
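"# Caveat: H_delta_val holds row sums of the Hessian (tf.gradients sums over the elements of\n",
"# g_delta), so strictly the diagonal entries should be H_delta_val[i][0] - H_delta_val_01.\n",
"H = np.array([[H_delta_val[0][0], H_delta_val_01],[H_delta_val_01, H_delta_val[1][0]]])\n",
"\n",
"# Newton-type step restricted to delta; the usual EPC for minimizing -2lL carries a minus sign, -H^{-1} g.\n",
"epc_approx = np.matmul(np.linalg.inv(H), g_delta_val)\n",
"print epc_approx"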
]
},
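{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of how the Hessian used above could be corrected. It assumes TensorFlow's documented behavior that `tf.gradients(ys, xs)` returns the derivative of the *sum* of `ys`: under that assumption `H_delta_val` holds the row sums $H_{00} + H_{01}$ and $H_{10} + H_{11}$ rather than the diagonal, and the diagonal is recovered by subtracting the off-diagonal `H_delta_val_01`. The sketch also includes the minus sign of the Newton step, $-H^{-1} g$. The inputs are copied by hand from the printed output above, so they are only as precise as that printout.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Values copied from the printed output above (limited precision).\n",
"g = np.array([23.57310677, 34.95653152])           # d(-2lL) / d delta at delta = 0\n",
"row_sums = np.array([841.83892822, 932.23809814])  # printed H_delta_val\n",
"h01 = 313.186                                       # printed H_delta_val_01, i.e. H[0, 1]\n",
"\n",
"# Each printed entry is H[i, 0] + H[i, 1]; subtracting H[0, 1] leaves H[i, i].\n",
"H = np.array([[row_sums[0] - h01, h01],\n",
"              [h01, row_sums[1] - h01]])\n",
"\n",
"# One Newton step on -2lL from delta = 0, restricted to the delta block.\n",
"epc = -np.linalg.solve(H, g)\n",
"print(epc)\n",
"```\n",
"\n",
"With these inputs the two entries come out at roughly -0.016 and -0.048. `np.linalg.solve` is used instead of forming `inv(H)` explicitly; for a 2 x 2 matrix the difference is cosmetic."
]
},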
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These should be small, since the data were generated from a population in which these direct effects are zero. \n",
"\n",
"\n",
"**TODO**: \n",
"\n",
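" * Here I'm using the full dataset to calculate the gradients, which is \"cheating\". It should be possible to include delta in the model from the start but tell the optimizer not to update this Variable (for example by leaving it out of the `var_list` passed to the optimizer's `minimize`, or by creating it with `trainable=False`). A running update of the gradient and Hessian of all parameters, including these fixed ones, could then be kept during training. \n",
"\n",
" * Online bootstrapping with Poisson weights: do everything once for each set of bootstrap weights, the only change being that each observation's contribution to -2lL is multiplied by its weight w_i."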
]
},
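{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way the online Poisson bootstrap from the last TODO item could look, sketched in NumPy on a deliberately simpler model: a single Bernoulli probability stands in for the latent class model because its weighted maximum-likelihood estimate has a closed form. The data, the number of replicates `B`, and the chunking are all hypothetical. As each chunk streams past, every replicate gives each observation an independent Poisson(1) weight $w_i$, and the replicate's estimate maximizes $\\sum_i w_i \\log p(y_i)$, i.e. minimizes the weighted $-2\\ell$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"y = rng.binomial(1, 0.3, size=1000)  # hypothetical binary data\n",
"B = 200                              # number of bootstrap replicates\n",
"\n",
"# Running weighted sufficient statistics, one entry per replicate.\n",
"sum_wy = np.zeros(B)\n",
"sum_w = np.zeros(B)\n",
"\n",
"for chunk in np.array_split(y, 10):  # data arrive in chunks, like minibatches\n",
"    # Fresh Poisson(1) weight for every (replicate, observation) pair.\n",
"    w = rng.poisson(1.0, size=(B, len(chunk)))\n",
"    sum_wy += w.dot(chunk)\n",
"    sum_w += w.sum(axis=1)\n",
"\n",
"# Weighted MLE of the Bernoulli probability in each replicate.\n",
"p_boot = sum_wy / sum_w\n",
"print(p_boot.std())  # bootstrap standard error of the estimate\n",
"```\n",
"\n",
"With many replicates, the spread of `p_boot` should approach the usual standard error $\\sqrt{\\hat{p}(1 - \\hat{p}) / n}$. For the latent class model the same idea would mean keeping `B` copies of the parameters and multiplying each observation's term in `min2ll` by its weight."
]
},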
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}