PMR tutorial 1 notes
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:eff41c41a5ebd1edc41f0a9e68a94ac18d5b417ce2cfe2baf22d49c0f7699eb9" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "heading", | |
"level": 1, | |
"metadata": {}, | |
"source": [ | |
"Question 1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First we define the factors that make up the distribution:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"phi_1 = lambda x,y: x+y\n",
"phi_2 = lambda y,z:(y+1)*(z+1)\n",
"def P_r_given_zx(r,x,z):\n",
"    if r==1:\n",
"        return 0.5 + 0.2*(x-z)\n",
"    elif r==0:\n",
"        return 1 - (0.5 + 0.2*(x-z))"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Then, since we're on a computer and every variable is binary, we can get the normalising constant $Z$ by brute force, summing the product of the factors over every joint assignment without bothering to rearrange the sums for efficiency:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n",
"# all are binary\n",
"Z = 0\n",
"for x in [0,1]:\n",
"    for y in [0,1]:\n",
"        for z in [0,1]:\n",
"            for r in [0,1]:\n",
"                Z += P_r_given_zx(r,x,z)*phi_1(x,y)*phi_2(y,z)"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"100000 loops, best of 3: 14.1 \u00b5s per loop\n" | |
] | |
} | |
], | |
"prompt_number": 10 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print(Z)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"21.0\n" | |
] | |
} | |
], | |
"prompt_number": 11 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's see if we can speed this up by passing messages and caching.\n", | |
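"First, we can eliminate $P(r|z,x)$: it is a conditional distribution, so $\\sum_{r} P(r|z,x) = 1$ for every $x$ and $z$ and the sum over $r$ drops out:"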
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n",
"# all are binary\n",
"Z = 0\n",
"for x in [0,1]:\n",
"    for y in [0,1]:\n",
"        for z in [0,1]:\n",
"            Z += phi_1(x,y)*phi_2(y,z)"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"100000 loops, best of 3: 3.93 \u00b5s per loop\n" | |
] | |
} | |
], | |
"prompt_number": 12 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can see that the answer is still the same, as expected:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print(Z)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"21.0\n" | |
] | |
} | |
], | |
"prompt_number": 13 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Then we can start passing messages: define a message $\\gamma_{z}(y) = \\sum_{z}\\phi_{2}(y,z)$, compute it once for each value of $y$, and store it in a dictionary so that the loop only has to index it:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n",
"# all are binary\n",
"Z = 0\n",
"gamma_z = {}\n",
"# looking at each case of y\n",
"gamma_z[0] = phi_2(0,0) + phi_2(0,1)\n",
"gamma_z[1] = phi_2(1,0) + phi_2(1,1)\n",
"for x in [0,1]:\n",
"    for y in [0,1]:\n",
"        Z += phi_1(x,y)*gamma_z[y]"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"100000 loops, best of 3: 2.1 \u00b5s per loop\n" | |
] | |
} | |
], | |
"prompt_number": 15 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print(Z)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"21.0\n" | |
] | |
} | |
], | |
"prompt_number": 16 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This final message won't actually speed it up any more, because there is nothing else left to factorize out:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n",
"# all are binary\n",
"Z = 0\n",
"gamma_z = {}\n",
"# looking at each case of y\n",
"gamma_z[0] = phi_2(0,0) + phi_2(0,1)\n",
"gamma_z[1] = phi_2(1,0) + phi_2(1,1)\n",
"# each case of x\n",
"gamma_y = {}\n",
"for x in [0,1]:\n",
"    gamma_y[x] = phi_1(x,0)*gamma_z[0] + phi_1(x,1)*gamma_z[1]\n",
"# gamma_y already sums over y and z, so only the sum over x remains\n",
"for x in [0,1]:\n",
"    Z += gamma_y[x]"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"100000 loops, best of 3: 2.31 \u00b5s per loop\n" | |
] | |
} | |
], | |
"prompt_number": 23 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
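"In fact, it's slightly slower (2.31 versus 2.1 microseconds per loop), probably because on a problem this small building the extra dictionary costs more than the few multiplications it saves."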
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print(Z)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"21.0\n" | |
] | |
} | |
], | |
"prompt_number": 20 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"gamma_z = {}\n", | |
"# looking at each case of y\n", | |
"gamma_z[0] = phi_2(0,0) + phi_2(0,1)\n", | |
"gamma_z[1] = phi_2(1,0) + phi_2(1,1)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 25 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"gamma_z" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 26, | |
"text": [ | |
"{0: 3, 1: 6}" | |
] | |
} | |
], | |
"prompt_number": 26 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# each case of x\n",
"gamma_y = {}\n",
"for x in [0,1]:\n",
"    gamma_y[x] = phi_1(x,0)*gamma_z[0] + phi_1(x,1)*gamma_z[1]\n",
"gamma_y"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 28, | |
"text": [ | |
"{0: 6, 1: 15}" | |
] | |
} | |
], | |
"prompt_number": 28 | |
}, | |
{ | |
"cell_type": "heading", | |
"level": 1, | |
"metadata": {}, | |
"source": [ | |
"Question 4" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So computing this just involves more sums:\n", | |
"\n", | |
"$$ P(x|z=1) = \\frac{\\sum_{y,r} \\frac{1}{Z} P(r|x,z=1) \\phi_{1}(x,y) \\phi_{2}(y,z=1) }{ \\sum_{x,y,r} \\frac{1}{Z} P(r|x,z=1) \\phi_{1}(x,y) \\phi_{2}(y,z=1) } $$\n", | |
"\n", | |
"Cancelling the Zs:\n", | |
"\n", | |
"$$ P(x|z=1) = \\frac{\\sum_{y,r} P(r|x,z=1) \\phi_{1}(x,y) \\phi_{2}(y,z=1) }{ \\sum_{x,y,r} P(r|x,z=1) \\phi_{1}(x,y) \\phi_{2}(y,z=1) } $$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"numerator = 0\n",
"denominator = 0\n",
"x = 0\n",
"z = 1\n",
"for y in [0,1]:\n",
"    for r in [0,1]:\n",
"        numerator += P_r_given_zx(r,x,z)*phi_1(x,y)*phi_2(y,z)\n",
"        for _x in [0,1]:\n",
"            denominator += P_r_given_zx(r,_x,z)*phi_1(_x,y)*phi_2(y,z)"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 34 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"P_x_given_z = numerator/denominator\n", | |
"print(P_x_given_z,1.0-P_x_given_z)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"0.2857142857142857 0.7142857142857143\n" | |
] | |
} | |
], | |
"prompt_number": 35 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This matches the solutions: $P(x=0|z=1) = 4/14 \\approx 0.286$.\n",
"\n",
"The next query, $P(y=1|r=0)$, is computed the same way:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"numerator = 0\n",
"denominator = 0\n",
"y = 1\n",
"r = 0\n",
"for x in [0,1]:\n",
"    for z in [0,1]:\n",
"        numerator += P_r_given_zx(r,x,z)*phi_1(x,y)*phi_2(y,z)\n",
"        for _y in [0,1]:\n",
"            denominator += P_r_given_zx(r,x,z)*phi_1(x,_y)*phi_2(_y,z)"
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 37 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"P_y_given_r = numerator/denominator\n", | |
"print(P_y_given_r,1.0-P_y_given_r)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"0.8737864077669902 0.12621359223300976\n" | |
] | |
} | |
], | |
"prompt_number": 38 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"((y+1)*(3.2*y + 1.3))/10.3" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 39, | |
"text": [ | |
"0.8737864077669902" | |
] | |
} | |
], | |
"prompt_number": 39 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
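
The notebook above spreads the calculation over many cells and relies on variables left over from earlier cells. As a rough, self-contained sketch of the same ideas, the script below recomputes $Z$ by brute force and by eliminating $r$, $z$ and $y$ in turn, then answers the two Question 4 queries by enumeration. The factor definitions are taken from the notebook; the helper names `joint`, `brute_force_Z`, `eliminated_Z` and `conditional` are hypothetical and only for illustration.

```python
from itertools import product

# Factors as defined in the Question 1 cells of the notebook.
def phi_1(x, y):
    return x + y

def phi_2(y, z):
    return (y + 1) * (z + 1)

def p_r_given_zx(r, x, z):
    p_one = 0.5 + 0.2 * (x - z)
    return p_one if r == 1 else 1.0 - p_one

def joint(x, y, z, r):
    """Unnormalised joint: P(r|x,z) * phi_1(x,y) * phi_2(y,z). Hypothetical helper."""
    return p_r_given_zx(r, x, z) * phi_1(x, y) * phi_2(y, z)

def brute_force_Z():
    """Sum the unnormalised joint over all 16 binary assignments."""
    return sum(joint(x, y, z, r) for x, y, z, r in product([0, 1], repeat=4))

def eliminated_Z():
    """Same sum, eliminating r (it sums to 1), then z, then y."""
    gamma_z = {y: sum(phi_2(y, z) for z in (0, 1)) for y in (0, 1)}
    gamma_y = {x: sum(phi_1(x, y) * gamma_z[y] for y in (0, 1)) for x in (0, 1)}
    return sum(gamma_y.values())

def conditional(query, evidence):
    """P(query | evidence) by enumeration; both are dicts such as {"x": 0}."""
    names = ("x", "y", "z", "r")
    numerator = denominator = 0.0
    for values in product([0, 1], repeat=4):
        assignment = dict(zip(names, values))
        if any(assignment[k] != v for k, v in evidence.items()):
            continue  # inconsistent with the evidence
        weight = joint(**assignment)
        denominator += weight
        if all(assignment[k] == v for k, v in query.items()):
            numerator += weight
    return numerator / denominator

if __name__ == "__main__":
    print(brute_force_Z(), eliminated_Z())
    print(conditional({"x": 0}, {"z": 1}))
    print(conditional({"y": 1}, {"r": 0}))
```

Because `conditional` normalises by the evidence-restricted sum, it never needs $Z$, which mirrors the cancellation in the Question 4 derivation. Enumeration is only practical here because there are four binary variables; the elimination order in `eliminated_Z` is the one used in the notebook, not necessarily the only valid one.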