Skip to content

Instantly share code, notes, and snippets.

@tillahoffmann
Created November 30, 2016 09:51
Show Gist options
  • Save tillahoffmann/f0f9f43204999a0fb83c9755332e7c97 to your computer and use it in GitHub Desktop.
Save tillahoffmann/f0f9f43204999a0fb83c9755332e7c97 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Generate some synthetic data for logistic regression\n",
"np.random.seed(1)\n",
"num_dims = 5\n",
"num_samples = 1000\n",
"\n",
"\n",
"design_matrix = np.random.normal(0, 1, (num_samples, num_dims))\n",
"coefficients = np.random.normal(0, 1, num_dims)\n",
"predictors = np.dot(design_matrix, coefficients)\n",
"probabilities = 1 / (1 + np.exp(-predictors))\n",
"observations = np.random.uniform(0, 1, num_samples) < probabilities"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Create a model\n",
"with tf.Graph().as_default() as graph:\n",
" tf_design_matrix = tf.placeholder(tf.float32, (num_samples, num_dims))\n",
" tf_coefficients = tf.Variable(np.random.normal(0, 1, num_dims), dtype=tf.float32)\n",
" tf_predictors = tf.reduce_sum(tf_design_matrix * tf_coefficients, 1)\n",
" tf_observations = tf.placeholder(tf.float32, num_samples)\n",
" tf_loss = tf.nn.sigmoid_cross_entropy_with_logits(tf_predictors, tf_observations)\n",
" \n",
" tf_hessian, = tf.hessians(tf_loss, tf_coefficients)\n",
" tf_train_op = tf.train.AdamOptimizer(.1).minimize(tf_loss)\n",
" tf_init_op = tf.global_variables_initializer()\n",
" \n",
"session = tf.Session(graph=graph)\n",
"session.run(tf_init_op)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1006f9390>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFztJREFUeJzt3X2QXfV93/H39967qwcMSKCFCglXci0/ECapqYqx22Zc\nY/PgeCz/EU9h0lp1mNG0JYkbt7Vh/IemzngmaTOhoXWZUCMDHQ+EUDdoPCSEYrv+pzwI28E8hjU4\naHmwlghkAgjtSt/+cX8rrnbvvavdu8tF57xfM3f23t/53Xt/R0ezn/09nHMiM5Ek1U9j2A2QJA2H\nASBJNWUASFJNGQCSVFMGgCTVlAEgSTVlAEhSTRkAklRTBoAk1VRr2A3oZ926dblp06ZhN0OSTigP\nPvjgi5k5Nl+9t3UAbNq0iT179gy7GZJ0QomIvz6eeg4BSVJNGQCSVFMGgCTV1LwBEBG7ImJfRDw8\nq/w3I+KJiHgkIv5TR/nVETFetl3cUX5JKRuPiKuWdjckSQt1PJPANwL/Dbh5piAi/imwDfjFzHwj\nIs4o5ecAlwG/AJwF/J+IeE9529eAjwMTwAMRsTszH12qHZEkLcy8AZCZ34+ITbOK/zXwu5n5Rqmz\nr5RvA24t5U9HxDhwftk2nplPAUTEraWuASBJQ7LYOYD3AP8kIu6LiP8bEf+wlG8A9nbUmyhlvcrn\niIgdEbEnIvZMTk4usnmSpPksNgBawFrgAuA/ALdFRADRpW72KZ9bmHl9Zm7NzK1jY/Oex9DVq29M\n8wd/8QQ/fOalRb1fkupgsQEwAXwr2+4HjgDrSvnZHfU2As/1KV8Wb0wf4drvjPPQxIHl+gpJOuEt\nNgD+FPgoQJnkHQVeBHYDl0XEiojYDGwB7gceALZExOaIGKU9Ubx70Mb30my0OxxTh48s11dI0glv\n3kngiLgF+AiwLiImgJ3ALmBXWRp6CNiemQk8EhG30Z7cnQauzMzD5XN+A7gLaAK7MvORZdgfAEaa\n7QA4fKTrKJMkieNbBXR5j03/vEf9rwJf7VJ+J3Dnglq3SDM9gGkDQJJ6quSZwK1Ge7emDxsAktRL\nJQOg2Qgi4PAR5wAkqZdKBgBAqxFMOQQkST1VOAAaTgJLUh8VDoBwGagk9VHdAGiGPQBJ6qOyAdBs\nNJhyFZAk9VTZAGg1wlVAktRHdQOgGZ4IJkl9VDcAGuGJYJLUR3UDoOkyUEnqp7oB4DJQSeqrugHg\nMlBJ6quyAdBsNLwUhCT1UdkAcBmoJPVX6QBwFZAk9VbdAPA8AEnqq7oB0GgYAJLUR4UDIJh2Gagk\n9VTdAHAZqCT1Vd0AaDQ8EUyS+qhsADQb9gAkqZ/KBoCrgCSpv+oGgOcBSFJf1Q2ApstAJamfeQMg\nInZFxL6IeLjLtn8fERkR68rriIhrI2I8Ih6KiPM66m6PiCfLY/vS7sZcrUYw7aUgJKmn4+kB3Ahc\nMrswIs4GPg4801F8KbClPHYA15W6pwE7gQ8C5wM7I2LtIA2fT6vR4LBDQJLU07wBkJnfB/Z32XQN\n8EWg87fsNuDmbLsXWBMR64GLgbszc39mvgTcTZdQWUqtZjBlD0CSelrUHEBEfAp4NjP/ctamDcDe\njtcTpaxX+bJxGagk9dda6BsiYjXwZeCibpu7lGWf8m6fv4P28BHvfOc7F9q8o0YaLgOVpH4W0wP4\ne8Bm4C8j4qfARuAHEfF3aP9lf3ZH3Y3Ac33K58jM6zNza2ZuHRsbW0Tz2pqNBpnYC5CkHhYcAJn5\n48w8IzM3ZeYm2r/cz8vMF4DdwGfLaqALgAOZ+TxwF3BRRKwtk78XlbJl02q2Ox2uBJKk7o5nGegt\nwP8D3hsRExFxRZ/qdwJPAePA/wD+DUBm7gd+B3igPL5SypZNq1ECwJVAktTVvHMAmXn5PNs3dTxP\n4Moe9XYBuxbYvkVrNdvZ5jyAJHVX3TOBj/YAHAKSpG4qGwDNEgBOAktSd5UNgJGjk8AGgCR1U9kA\naDbKHICTwJLUVWUDYMRloJLUV2UDYGYOwCEgSequsgHQcghIkvqqcAA4BCRJ/VQ2AJquApKkviob\nACNlCMjzACSpu8oGwMwk8JRnAktSV5UNgJlloPYAJKm7ygZA06uBSlJflQ2AEa8GKkl9VTYAml4N\nVJL6qmwAtDwTWJL6qm4ANF0GKkn9VDcAXAYqSX1VNwBcBipJfVU2AI6eCGYASFJXlQ2Ao5eCcAhI\nkrqqbAB4MThJ6q+yAeAyUEnqr8IB4DJQSeqnwgHgMlBJ6mfeAIiIXRGxLyIe7ij7zxHxeEQ8FBH/\nOyLWdGy7OiLGI+KJiLi4o/ySUjYeEVct/a4cq9EIGmEPQJJ6OZ4ewI3AJbPK7gbOzcxfBP4KuBog\nIs4BLgN+obznv0dEMyKawNeAS4FzgMtL3WXVajSY8mqgktTVvAGQmd8H9s8q+4vMnC4v7wU2lufb\ngFsz843MfBoYB84vj/HMfCozDwG3lrrLqtUMDntPYEnqainmAH4d+LPyfAOwt2PbRCnrVb6smo2w\nByBJPQwUABHxZWAa+OZMUZdq2ae822fuiIg9EbFncnJykObRaoRzAJLUw6IDICK2A58Efi0zZ37L\nTgBnd1TbCDzXp3yOzLw+M7dm5taxsbHFNg9oXxHU8wAkqbtFBUBEXAJ8CfhUZr7WsWk3cFlErIiI\nzcAW4H7gAWBLRGyOiFHaE8W7B2v6/FqN8IYwktRDa74KEXEL8BFgXURMADtpr/pZAdwdEQD3Zua/\nysxHIuI24FHaQ0NXZubh8jm/AdwFNIFdmfnIMuzPMVrNsAcgST3MGwCZeXmX4hv61P8q8NUu5XcC\ndy6odQMaaTQ8EUySeqjsmcDQvjH8tKuAJKmrSgdAqxn2ACSph0oHwEiz4Q1hJKmHigdAMDVtD0CS\nuql4ADSY9lIQktRVpQOg1WxwyElgSeqq0gEw2vREMEnqpdIB0PI8AEnqqdIBMNLyPABJ6qXaAdAI\nDtkDkKSuqh0AngksST1VOgA8E1iSeqt0AIw0Gw4BSVIPFQ+AcAhIknqoeAC4DFSSeql0AMzcEvLN\nO1ZKkmZUOgBGm+170U85DCRJc1Q6AFrN9u55QThJmqvSATBSAmBq2h6AJM1W8QAoQ0D2ACRpjooH\nQOkBuBJIkuaoRQB4LoAkzVXxAGgPAXk2sCTNVfEAsAcgSb1UOgBajZnzAOwBSNJs8wZAROyKiH0R\n8XBH2WkRcXdEPFl+ri3lERHXRsR4RDwUEed1vGd7qf9kRGxfnt051kjLSWBJ6uV4egA3ApfMKrsK\nuCcztwD3lNcAlwJbymMHcB20AwPYCXwQOB/YORMay2mkMRMADgFJ0mzzBkBmfh/YP6t4G3BTeX4T\n8OmO8puz7V5gTUSsBy4G7s7M/Zn5EnA3c0Nlyc1MAntjeEmaa7FzAGdm5vMA5ecZpXwDsLej3kQp\n61W+rGYuBeEqIEmaa6kngaNLWfYpn/sBETsiYk9E7JmcnByoMaOuApKknhYbAD8rQzuUn/tK+QRw\ndke9jcBzfcrnyMzrM3NrZm4dGxtbZPPaWk1XAUlSL4sNgN3AzEqe7cAdHeWfLauBLgAOlCGiu4CL\nImJtmfy9qJQtq6OXgjhiD0CSZmvNVyEibgE+AqyLiAnaq3l+F7gtIq4AngE+U6rfCXwCGAdeAz4H\nkJn7I+J3gAdKva9k5uyJ5SV39GJw0/YAJGm2eQMgMy/vsenCLnUTuLLH5+wCdi2odQMa8X4AktRT\ntc8EPnotIIeAJGm2SgfAm6uA7AFI0myVDoCW9wOQpJ4qHQAj3hReknqqdgA07AFIUi+VDoBGI2g2\nwgCQpC4qHQDQvieAQ0CSNFflA2C01eCQJ4JJ0hyVD4AVrYZXA5WkLmoQAE17AJLUReUDYLTV4A0D\nQJLmqH4ANBscmj487GZI0ttO5QNgxYg9AEnqpvIB0O4BGACSNFvlA8AegCR1V/kAsAcgSd1VPgBc\nBipJ3VU+ANrLQF0FJEmz1SIA7AFI0lyVD4AVnggmSV1VPgDsAUhSd5UPgBWtpj0ASeqi8gEwWq4G\nmuk9ASSpU+UDYEWrvYteElqSjlWbAHAYSJKONVAARMRvR8QjEfFwRNwSESsjYnNE3BcRT0bEH0fE\naKm7orweL9s3LcUOzGd0pgdgAEjSMRYdABGxAfgtYGtmngs0gcuA3wOuycwtwEvAFeUtVwAvZea7\ngWtKvWVnD0CSuht0CKgFrIqIFrAaeB74KHB72X4T8OnyfFt5Tdl+YUTEgN8/L3sAktTdogMgM58F\nfh94hvYv/gPAg8DLmTldqk0AG8rzDcDe8t7pUv/0xX7/8VrRagIGgCTNNsgQ0Fraf9VvBs4CTgIu\n7VJ1Zv1lt7/256zNjIgdEbEnIvZMTk4utnlHjTZnhoC8HpAkdRpkCOhjwNOZOZmZU8C3gA8Da8qQ\nEMBG4LnyfAI4G6BsPxXYP/tDM/P6zNyamVvHxsYGaF6bQ0CS1N0gAfAMcEFErC5j+RcCjwLfBX61\n1NkO3FGe7y6vKdu/k2/B2VlOAktSd4PMAdxHezL3B8CPy2ddD3wJ+EJEjNMe47+hvOUG4PRS/gXg\nqgHafdzsAUhSd635q/SWmTuBnbOKnwLO71L3IPCZQb5vMWYmgZ0DkKRjVf5M4FGHgCSpq8oHwAqH\ngCSpq9oEgD0ASTpW5QPASWBJ6q7yATAzCXzQSWBJOkYNAqC9iwen7AFIUqfKB0CjEawaafL6oen5\nK0tSjVQ+AABWjTZ57ZBDQJLUqR4BMNLk9SkDQJI61SMARpu8bg9Ako5RiwBYPWoPQJJmq0UArBpx\nDkCSZqtHADgEJElz1CIAHAKSpLlqEQCrRlr2ACRplnoEwGjDHoAkzVKLAFg92uI1zwSWpGPUIgBW\njjQ5OHWEI0eW/RbEknTCqEUArB71iqCSNFutAsBzASTpTbUIgJUj7QBwJZAkvakWATDTA3AlkCS9\nqRYBsGrEISBJmq0eATDqEJAkzVaLAFg92gLg9SnPBZCkGQMFQESsiYjbI+LxiHgsIj4UEadFxN0R\n8WT5ubbUjYi4NiLGI+KhiDhvaXZhfg4BSdJcg/YA/hD488x8H/BLwGPAVcA9mbkFuKe8BrgU2FIe\nO4DrBvzu4+YyUEmaa9EBEBGnAL8M3ACQmYcy82VgG3BTqXYT8OnyfBtwc7bdC6yJiPWLbvkCvGNF\newjobw86BCRJMwbpAbwLmAS+ERE/jIivR8RJwJmZ+TxA+XlGqb8B2Nvx/olStuzesbIdAK8YAJJ0\n1CAB0ALOA67LzA8Ar/LmcE830aVszsV5ImJHROyJiD2Tk5MDNO9NI80Gq0aavHJwakk+T5KqYJAA\nmAAmMvO+8vp22oHws5mhnfJzX0f9szvevxF4bvaHZub1mbk1M7eOjY0N0LxjnbyyZQ9AkjosOgAy\n8wVgb0S8txRdCDwK7Aa2l7LtwB3l+W7gs2U10AXAgZmhorfCyStbvPKGPQBJmtEa8P2/CXwzIkaB\np4DP0Q6V2yLiCuAZ4DOl7p3AJ4Bx4LVS9y1z8soRewCS1GGgAMjMHwFbu2y6sEvdBK4c5PsGcfLK\nFj83ACTpqFqcCQwzcwAOAUnSjPoEwAqHgCSpU30CwB6AJB2jNgGwZvUIB6eOcNB7AkgSUKMAWHvS\nKAAvv2YvQJKgRgFw2up2AOx/9dCQWyJJbw/1CYDSA3jpNQNAkqCGAfA39gAkCahRAMzMAbxkAEgS\nUKMAWLNqhAjnACRpRm0CoNVscOqqEQNAkoraBADAmSev5IWfHxx2MyTpbaFWAbB+zUqeP/D6sJsh\nSW8LtQqAs9as4rmX7QFIEtQtAE5dyf5XD3k5CEmiZgGw/tRVADx/wF6AJNUqAN55+moAnn7xb4fc\nEkkavloFwHvOPBmAx194ZcgtkaThq1UAnLpqhLNOXckTBoAk1SsAAN6//hR+/OyBYTdDkoaudgHw\n4Xev46nJV9m7/7VhN0WShqp2AXDh+84A4I4fPTvklkjScLWG3YC32qZ1J/Gx95/Bdd/7CWecspJ3\nn/EORpvtHDySefS2kQenDtNqBitHmqwaabJqtP1zZXmsGmky0gwiYsh7JEmLU7sAAPjKtnP5l9+4\nny/e/tBAn9NsREcoNFjRapxwgXBitVaqj/etP4X/evkHlvU7ahkAZ61ZxZ99/pd57PmfM/nKG0wf\nSaD9y3Dml/nKkSaHjySvTx3m9anDHDx0+M3npZfwemfZocO8MX1kuDu2QEkOuwmSejh77apl/46B\nAyAimsAe4NnM/GREbAZuBU4DfgD8i8w8FBErgJuBfwD8DfDPMvOng37/YjUbwbkbTh3W10vS0C3F\nJPDngcc6Xv8ecE1mbgFeAq4o5VcAL2Xmu4FrSj1J0pAMFAARsRH4FeDr5XUAHwVuL1VuAj5dnm8r\nrynbL4wTbcBckipk0B7AfwG+CMwMfp8OvJyZ0+X1BLChPN8A7AUo2w+U+pKkIVh0AETEJ4F9mflg\nZ3GXqnkc2zo/d0dE7ImIPZOTk4ttniRpHoP0AP4R8KmI+CntSd+P0u4RrImImcnljcBz5fkEcDZA\n2X4qsH/2h2bm9Zm5NTO3jo2NDdA8SVI/iw6AzLw6Mzdm5ibgMuA7mflrwHeBXy3VtgN3lOe7y2vK\n9u9kpusQJWlIluNSEF8CvhAR47TH+G8o5TcAp5fyLwBXLcN3S5KO05KcCJaZ3wO+V54/BZzfpc5B\n4DNL8X2SpMHF23kUJiImgb8e4CPWAS8uUXNOFO5z9dVtf8F9Xqi/m5nzTqK+rQNgUBGxJzO3Drsd\nbyX3ufrqtr/gPi+X2l0OWpLUZgBIUk1VPQCuH3YDhsB9rr667S+4z8ui0nMAkqTeqt4DkCT1UMkA\niIhLIuKJiBiPiMqccBYRZ0fEdyPisYh4JCI+X8pPi4i7I+LJ8nNtKY+IuLb8OzwUEecNdw8WLyKa\nEfHDiPh2eb05Iu4r+/zHETFayleU1+Nl+6ZhtnuxImJNRNweEY+X4/2hqh/niPjt8v/64Yi4JSJW\nVu04R8SuiNgXEQ93lC34uEbE9lL/yYjY3u27jkflAqDcoOZrwKXAOcDlEXHOcFu1ZKaBf5eZ7wcu\nAK4s+3YVcE+5B8M9vHmW9aXAlvLYAVz31jd5ydTtvhN/CPx5Zr4P+CXa+17Z4xwRG4DfArZm5rlA\nk/YlZqp2nG8ELplVtqDjGhGnATuBD9I+6XbnTGgsWGZW6gF8CLir4/XVwNXDbtcy7esdwMeBJ4D1\npWw98ER5/kfA5R31j9Y7kR60Lyp4D+0LDn6b9pVlXwRas485cBfwofK8VerFsPdhgft7CvD07HZX\n+Tjz5uXiTyvH7dvAxVU8zsAm4OHFHlfgcuCPOsqPqbeQR+V6AHTcd6DovCdBZZQu7weA+4AzM/N5\ngPLzjFKtKv8WdbvvxLuASeAbZdjr6xFxEhU+zpn5LPD7wDPA87SP24NU+zjPWOhxXbLjXcUAOK77\nDpzIIuIdwP8C/m1m/rxf1S5lJ9S/xXLdd+JtrgWcB1yXmR8AXqX/xRNP+H0uQxjbgM3AWcBJtIdA\nZqvScZ5Pr31csn2vYgAcve9A0XlPghNeRIzQ/uX/zcz8Vin+WUSsL9vXA/tKeRX+LZblvhNvcxPA\nRGbeV17fTjsQqnycPwY8nZmTmTkFfAv4MNU+zjMWelyX7HhXMQAeALaU1QOjtCeSdg+5TUsiIoL2\nZbUfy8w/6NjUea+F2fdg+GxZTXABcGCmq3miyBredyIzXwD2RsR7S9GFwKNU+DjTHvq5ICJWl//n\nM/tc2ePcYaHH9S7goohYW3pOF5WyhRv2hMgyTbJ8Avgr4CfAl4fdniXcr39Mu6v3EPCj8vgE7bHP\ne4Any8/TSv2gvSLqJ8CPaa+wGPp+DLD/HwG+XZ6/C7gfGAf+BFhRyleW1+Nl+7uG3e5F7uvfB/aU\nY/2nwNqqH2fgPwKPAw8D/xNYUbXjDNxCe45jivZf8lcs5rgCv172fRz43GLb45nAklRTVRwCkiQd\nBwNAkmrKAJCkmjIAJKmmDABJqikDQJJqygCQpJoyACSppv4/2+AXYm3M+ToAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1045f4400>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Minimize the loss/maximize the likelihood/maximize the posterior assuming a flat prior\n",
"num_steps = 1000\n",
"feed_dict = {tf_design_matrix: design_matrix, tf_observations: observations}\n",
"\n",
"trace = []\n",
"for _ in range(num_steps):\n",
" _, loss = session.run([tf_train_op, tf_loss], feed_dict)\n",
" trace.append(np.sum(loss))\n",
" \n",
"plt.plot(trace)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.924755283622 -1.00163 0.100013 0.768688848809\n",
"1.12888989541 1.37404 0.111109 -2.20637732508\n",
"-1.12879126859 -1.14417 0.101273 0.151897947262\n",
"-0.724737622083 -0.827122 0.0955059 1.07201712824\n",
"0.623571208783 0.66956 0.0888993 -0.517310305816\n"
]
}
],
"source": [
"estimate, hessian = session.run([tf_coefficients, tf_hessian], feed_dict)\n",
"\n",
"for x, y, yerr in zip(coefficients, estimate, np.sqrt(np.diag(np.linalg.inv(hessian)))):\n",
" print(x, y, yerr, (x - y) / yerr)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment