Skip to content

Instantly share code, notes, and snippets.

@stijnvanhoey
Last active May 26, 2020 06:26
Show Gist options
  • Save stijnvanhoey/0b6a89c1ab12e4c9a4bff123cdf281b4 to your computer and use it in GitHub Desktop.
Save stijnvanhoey/0b6a89c1ab12e4c9a4bff123cdf281b4 to your computer and use it in GitHub Desktop.
SIR model implementation tryout
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.integrate import solve_ivp\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## As simple and as close to plain scipy as possible"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# User specifies...\n",
"\n",
"# ... the state transitions (right-hand side of the ODE system)\n",
"def integrate(t, y, parameters):\n",
"    \"\"\"Basic SIR model: return the derivatives (dS, dI, dR) of the state vector `y`.\"\"\"\n",
"    # scipy hands over plain sequences, so the unpacking is positional and must\n",
"    # follow the insertion order of `initial_states` and `parameters` below\n",
"    # (how to check that it is correct?)\n",
"    S, I, R = y\n",
"    beta, gamma = parameters\n",
"\n",
"    # Model equations, written as the two flows between compartments\n",
"    N = S + I + R\n",
"    infection_flow = beta*S*I/N\n",
"    recovery_flow = gamma*I\n",
"\n",
"    return -infection_flow, infection_flow - recovery_flow, recovery_flow\n",
"\n",
"# ... time span, parameters and initial conditions\n",
"time = [0, 250]\n",
"parameters = {\"beta\": 0.5, \"gamma\": 0.3}  # insertion order = unpacking order in integrate\n",
"initial_states = {\"S\": 7900000, \"I\": 10, \"R\": 0}  # insertion order = unpacking order in integrate\n",
"\n",
"# -> runs the model with scipy directly; the parameter list is wrapped in an extra\n",
"#    list so that solve_ivp forwards it as the single `parameters` argument\n",
"output = solve_ivp(\n",
"    integrate,\n",
"    time,\n",
"    list(initial_states.values()),\n",
"    args=[list(parameters.values())],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f4d7674a2d0>]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"# the rows of output[\"y\"] follow the y0 vector, i.e. the key order of `initial_states`\n",
"for name, values in zip(initial_states, output[\"y\"]):\n",
"    ax.plot(output[\"t\"], values, label=name)\n",
"ax.set(xlabel=\"time\", ylabel=\"number of individuals\", title=\"SIR model (plain solve_ivp)\")\n",
"ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split the `integrate` specification from the model-user based specification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a middleware function `create_fun` that translates a user-defined integration function, which takes the variables/parameters as individual arguments, into a scipy-compatible function."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def create_fun(initial_states, parameters):\n",
"    \"\"\"Convert the user-defined `integrate` function into a scipy-compatible function.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    initial_states : dict\n",
"        State name -> initial value. Its key order must match the order of the\n",
"        `y0` vector handed to `solve_ivp`.\n",
"    parameters : dict\n",
"        Parameter name -> value, forwarded to `integrate` as keyword arguments.\n",
"\n",
"    Returns\n",
"    -------\n",
"    callable\n",
"        `func(t, y, *args)` that maps `y` back to the state names and calls the\n",
"        global `integrate(t, **states, **parameters)`. Because states and\n",
"        parameters are passed by name, the argument order in the signature of\n",
"        `integrate` does not have to match the dict order. Extra positional\n",
"        `args` forwarded by `solve_ivp(..., args=...)` are accepted but ignored:\n",
"        the parameter values always come from `parameters`.\n",
"    \"\"\"\n",
"    # OPTION TO ADD CHECKS on integrate function arguments\n",
"    # versus the defined integrate function\n",
"    # ... (not done for the example)\n",
"    state_names = list(initial_states)\n",
"\n",
"    def func(t, y, *args):\n",
"        states = dict(zip(state_names, y))\n",
"        # `integrate` is looked up at call time, so it may be (re)defined after create_fun\n",
"        return integrate(t, **states, **parameters)\n",
"\n",
"    return func"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# User specifies...\n",
"\n",
"# ... the state transitions, with every state variable and parameter as its own argument\n",
"# (for bigger models this becomes a long list of arguments!)\n",
"def integrate(t, S, I, R, beta, gamma):\n",
"    \"\"\"Basic SIR model: return the derivatives (dS, dI, dR).\"\"\"\n",
"    N = S + I + R\n",
"    infection_flow = beta*S*I/N\n",
"    recovery_flow = gamma*I\n",
"\n",
"    return -infection_flow, infection_flow - recovery_flow, recovery_flow\n",
"\n",
"# ... time span, parameters and initial conditions\n",
"time = [0, 150]\n",
"parameters = {\"beta\": 0.5, \"gamma\": 0.3}\n",
"initial_states = {\"S\": 7900000, \"I\": 10, \"R\": 0}\n",
"\n",
"# -> wraps the model function for scipy and runs it\n",
"fun = create_fun(initial_states, parameters)\n",
"output = solve_ivp(\n",
"    fun,\n",
"    time,\n",
"    list(initial_states.values()),\n",
"    args=list(parameters.values()),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f4d76649790>]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"# the rows of output[\"y\"] follow the y0 vector, i.e. the key order of `initial_states`\n",
"for name, values in zip(initial_states, output[\"y\"]):\n",
"    ax.plot(output[\"t\"], values, label=name)\n",
"ax.set(xlabel=\"time\", ylabel=\"number of individuals\", title=\"SIR model (create_fun wrapper)\")\n",
"ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Abstract away the boilerplate in a base class and define a new model as a subclass"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class BaseModel:\n",
"    \"\"\"Base class that wires a model's `integrate` method to `scipy.integrate.solve_ivp`.\n",
"\n",
"    Subclasses define the class attributes `state_names` (list of str) and\n",
"    `parameter_names` (list of str) and implement `integrate(t, <states...>,\n",
"    <parameters...>)`. That method receives every state and parameter as a\n",
"    keyword argument and returns the derivatives in the order of `state_names`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, states, parameters):\n",
"        \"\"\"Store the initial states and parameter values.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        states : dict\n",
"            State name -> initial value.\n",
"        parameters : dict\n",
"            Parameter name -> value.\n",
"        \"\"\"\n",
"        self.parameters = parameters\n",
"        self.initial_states = states\n",
"\n",
"        # OPTION TO ADD CHECKS on incoming arguments on initialization\n",
"        # versus the defined integrate variables and class attributes of the subclass\n",
"        # ... @Joris -> we should add POC here as well, as it needs to compare the\n",
"        # incoming dicts with the class attributes of the subclass?\n",
"        # (and I'd rather not add an __init__ to the subclass)\n",
"\n",
"    def integrate(self):\n",
"        \"\"\"To be overridden in subclasses.\"\"\"\n",
"        raise NotImplementedError\n",
"\n",
"    def create_fun(self):\n",
"        \"\"\"Convert the `integrate` method into a scipy-compatible `func(t, y, *args)`.\n",
"\n",
"        The state vector `y` is mapped back to `state_names` and passed, together\n",
"        with `self.parameters`, as keyword arguments to `integrate`. The order of\n",
"        the dicts given to `__init__` therefore does not matter. Extra positional\n",
"        `args` forwarded by `solve_ivp` are accepted but ignored.\n",
"        \"\"\"\n",
"\n",
"        def func(t, y, *args):\n",
"            states = self.array_to_variables(y)\n",
"            return self.integrate(t, **states, **self.parameters)\n",
"\n",
"        return func\n",
"\n",
"    def sim(self, time, **solver_kwargs):\n",
"        \"\"\"Simulate the model over the time span `time`.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        time : sequence of two floats\n",
"            Integration interval `[t0, tf]`, passed as `t_span` to `solve_ivp`.\n",
"        **solver_kwargs\n",
"            Extra keyword arguments for `solve_ivp` (e.g. `method`, `t_eval`).\n",
"\n",
"        Returns\n",
"        -------\n",
"        t : ndarray\n",
"            Time points of the solution.\n",
"        states : dict\n",
"            State name -> ndarray with the simulated trajectory.\n",
"        \"\"\"\n",
"        fun = self.create_fun()\n",
"        # build y0 in `state_names` order so that it lines up with array_to_variables;\n",
"        # relying on the insertion order of `initial_states` could silently mislabel states\n",
"        y0 = [self.initial_states[name] for name in self.state_names]\n",
"        output = solve_ivp(fun, time, y0, **solver_kwargs)\n",
"        return output[\"t\"], self.array_to_variables(output[\"y\"])  # map back to variable names\n",
"\n",
"    def array_to_variables(self, y):\n",
"        \"\"\"Convert array (used by scipy) to dictionary (used by model API).\"\"\"\n",
"        return dict(zip(self.state_names, y))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# User (model developer) specifies...\n",
"\n",
"# ... the state transitions in a subclass\n",
"class SIR(BaseModel):\n",
"    \"\"\"Basic SIR model built on top of BaseModel.\"\"\"\n",
"\n",
"    # names of the state variables and of the parameters\n",
"    state_names = ['S', 'I', 'R']\n",
"    parameter_names = ['beta', 'gamma']\n",
"\n",
"    @staticmethod\n",
"    def integrate(t, S, I, R, beta, gamma):  # every variable and parameter as an argument: a long list for bigger models(!)\n",
"        \"\"\"Return the SIR derivatives (dS, dI, dR).\"\"\"\n",
"        N = S + I + R\n",
"        infection_flow = beta*S*I/N\n",
"        recovery_flow = gamma*I\n",
"\n",
"        return -infection_flow, infection_flow - recovery_flow, recovery_flow\n",
"\n",
"# ... parameters and initial conditions\n",
"parameters = {\"beta\": 0.5, \"gamma\": 0.3}\n",
"initial_states = {\"S\": 7900000, \"I\": 10, \"R\": 0}\n",
"\n",
"# -> user creates the model instance\n",
"sir_model = SIR(initial_states, parameters)\n",
"\n",
"# -> user simulates a given time span\n",
"time = [0, 150]\n",
"t, output = sir_model.sim(time)\n",
"\n",
"# -> the instance `sir_model` can then be reused for fitting, Monte Carlo runs, ..."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f880e210110>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"# `output` maps each state name to its trajectory (see BaseModel.array_to_variables)\n",
"for name, values in output.items():\n",
"    ax.plot(t, values, label=name)\n",
"ax.set(xlabel=\"time\", ylabel=\"number of individuals\", title=\"SIR model (BaseModel subclass)\")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:COVID_MODEL]",
"language": "python",
"name": "conda-env-COVID_MODEL-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@JennaVergeynst
Copy link

Beautiful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment