Created
January 27, 2014 18:21
-
-
Save aflaxman/8654408 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"!date" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"Mon Jan 27 10:19:59 PST 2014\r\n" | |
] | |
} | |
], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%matplotlib inline\n", | |
"import pymc, pymc as pm, numpy as np" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
" From: pymc@googlegroups.com [mailto:pymc@googlegroups.com] On Behalf Of Quentin CAUDRON\n", | |
" Sent: Monday, January 20, 2014 4:16 PM\n", | |
" To: pymc@googlegroups.com\n", | |
" Subject: [pymc] SIR model with PyMC3 - dynamical / compartmental\n", | |
"\n", | |
"Hi,\n", | |
"\n", | |
"I'm trying to put together some parameter inference for a compartmental epidemiological model. The simplest form looks like this :\n", | |
"\n", | |
"\\begin{align}\n", | |
"dS/dt &= - r * g * S * I\\\\\n", | |
"dI/dt &= g * I ( r * S - 1 )\n", | |
"\\end{align}\n", | |
"\n", | |
"I'd like to use PyMC3 to infer the parameters r and g. \n", | |
"\n", | |
"odeint won't accept pymc variables, so that won't work. I tried using manual integration ( just a basic forward Euler works here ), but trying to do arithmetic on pymc variables in a loop gives me ValueError: setting an array element with a sequence. :\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"t = np.linspace(0,1,20)\n", | |
"solution = np.zeros_like(t)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"with pymc.Model() as model :\n", | |
" r0 = pymc.Normal(\"r0\", 15, sd=10)\n", | |
" gamma = pymc.Uniform(\"gamma\", 0.3, 1.)\n", | |
" \n", | |
" # Use forward Euler to solve\n", | |
" dt = t[1] - t[0]\n", | |
" S = np.zeros_like(solution)\n", | |
" I = np.zeros_like(solution)\n", | |
" S[0] = 0.99\n", | |
" I[0] = 0.01\n", | |
"\n", | |
" for i in range(len(t[1:])) :\n", | |
" S[i] = S[i-1] + dt * (-r0 * gamma * S[i-1] * I[i-1])\n", | |
" I[i] = I[i-1] + dt * ( r0 * gamma * S[i-1] * I[i-1] - gamma * I[i-1])\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "ValueError", | |
"evalue": "setting an array element with a sequence.", | |
"output_type": "pyerr", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-5-c6feb0e9eed4>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m \u001b[0mS\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mdt\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mr0\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mgamma\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mI\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 14\u001b[0m \u001b[0mI\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mI\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mdt\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m \u001b[0mr0\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mgamma\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m*\u001b[0m 
\u001b[0mI\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mgamma\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mI\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mValueError\u001b[0m: setting an array element with a sequence." | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
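"The problem is that `np.zeros_like(solution)` gives `S` (and `I`) the same dtype as `solution`, which is `float64`; a float array cannot hold a symbolic PyMC variable, so the element assignment fails." | |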
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"S.dtype" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 6, | |
"text": [ | |
"dtype('float64')" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Forcing `S` and `I` to have dtype `object` sorts that out." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"with pm.Model() as model :\n", | |
" r0 = pm.Normal(\"r0\", 15, sd=10)\n", | |
" gamma = pm.Uniform(\"gamma\", 0.3, 1.)\n", | |
" \n", | |
" # Use forward Euler to solve\n", | |
" dt = t[1] - t[0]\n", | |
" S = np.zeros_like(solution, dtype=object) # FIX #1\n", | |
" I = np.zeros_like(solution, dtype=object) # FIX #2\n", | |
" S[0] = 0.99\n", | |
" I[0] = 0.01\n", | |
"\n", | |
" for i in range(len(t[1:])) :\n", | |
" S[i] = S[i-1] + dt * (-r0 * gamma * S[i-1] * I[i-1])\n", | |
" I[i] = I[i-1] + dt * ( r0 * gamma * S[i-1] * I[i-1] - gamma * I[i-1])\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Success!" | |
] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment