Skip to content

Instantly share code, notes, and snippets.

@CamDavidsonPilon
Last active August 3, 2020 14:42
Show Gist options
  • Save CamDavidsonPilon/d6c2f6acea390d81a5c33150222ac70e to your computer and use it in GitHub Desktop.
Save CamDavidsonPilon/d6c2f6acea390d81a5c33150222ac70e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from lifelines import CoxPHFitter\n",
"from lifelines.datasets import load_rossi\n",
"\n",
"rossi = load_rossi()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>week</th>\n",
" <th>arrest</th>\n",
" <th>fin</th>\n",
" <th>age</th>\n",
" <th>race</th>\n",
" <th>wexp</th>\n",
" <th>mar</th>\n",
" <th>paro</th>\n",
" <th>prio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>27</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>17</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>19</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>52</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>52</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>19</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" week arrest fin age race wexp mar paro prio\n",
"0 20 1 0 27 1 0 0 1 3\n",
"1 17 1 0 18 1 0 0 1 8\n",
"2 25 1 0 19 0 1 0 1 13\n",
"3 52 0 1 23 1 1 1 1 1\n",
"4 52 0 0 19 0 1 0 1 3"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rossi.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# let's b-spline age: age enters the model through bs(age, df=4), the rest stay linear\n",
"spline_formula = \"fin + bs(age, df=4) + wexp + mar + paro + prio\"\n",
"cph = CoxPHFitter().fit(\n",
"    rossi, duration_col=\"week\", event_col=\"arrest\", formula=spline_formula\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>week</th>\n",
" <th>arrest</th>\n",
" <th>fin</th>\n",
" <th>age</th>\n",
" <th>race</th>\n",
" <th>wexp</th>\n",
" <th>mar</th>\n",
" <th>paro</th>\n",
" <th>prio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>52.0</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>17.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>52.0</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>17.551020</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>52.0</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>18.102041</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>52.0</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>18.653061</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>52.0</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>19.204082</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" week arrest fin age race wexp mar paro prio\n",
"0 52.0 0.0 0.5 17.000000 1.0 1.0 0.0 1.0 2.0\n",
"1 52.0 0.0 0.5 17.551020 1.0 1.0 0.0 1.0 2.0\n",
"2 52.0 0.0 0.5 18.102041 1.0 1.0 0.0 1.0 2.0\n",
"3 52.0 0.0 0.5 18.653061 1.0 1.0 0.0 1.0 2.0\n",
"4 52.0 0.0 0.5 19.204082 1.0 1.0 0.0 1.0 2.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# now we need to \"extend\" our data to plot it\n",
"# we'll plot age over its observed range\n",
"N_AGE_POINTS = 50  # grid resolution; reused below so the row count always matches\n",
"age_range = np.linspace(rossi['age'].min(), rossi['age'].max(), N_AGE_POINTS)\n",
"\n",
"# need a matrix with every variable held at its central value, _except_ for age.\n",
"# NOTE: `_central_values` is a private lifelines attribute, and the values are not plain\n",
"# means (the output below shows arrest == 0.0 and week == 52.0) -- presumably medians.\n",
"x_bar = cph._central_values\n",
"df_varying_age = pd.concat([x_bar] * len(age_range)).reset_index(drop=True)\n",
"df_varying_age['age'] = age_range\n",
"\n",
"df_varying_age.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# index the predictions by age so the x-axis shows age rather than the row number\n",
"log_partial_hazard = cph.predict_log_partial_hazard(df_varying_age)\n",
"log_partial_hazard.index = df_varying_age['age'].values\n",
"\n",
"ax = log_partial_hazard.plot()\n",
"# trailing `;` suppresses the Axes repr in the cell output\n",
"ax.set(xlabel=\"age\", ylabel=\"log partial hazard\", title=\"B-spline age effect (other covariates at central values)\");"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# compare to _not_ bspline-ing: here age enters the model as a plain linear term\n",
"# NOTE: this rebinds `cph`, `x_bar` and `df_varying_age`; re-run the spline cells before\n",
"# re-plotting the spline fit.\n",
"cph = CoxPHFitter().fit(rossi, \"week\", \"arrest\", formula=\"fin + age + wexp + mar + paro + prio\")\n",
"\n",
"N_AGE_POINTS = 50  # grid resolution; reused below so the row count always matches\n",
"age_range = np.linspace(rossi['age'].min(), rossi['age'].max(), N_AGE_POINTS)\n",
"\n",
"# need a matrix with every variable held at its central value, _except_ for age.\n",
"x_bar = cph._central_values\n",
"df_varying_age = pd.concat([x_bar] * len(age_range)).reset_index(drop=True)\n",
"df_varying_age['age'] = age_range\n",
"\n",
"# index the predictions by age so the x-axis shows age rather than the row number\n",
"log_partial_hazard = cph.predict_log_partial_hazard(df_varying_age)\n",
"log_partial_hazard.index = age_range\n",
"\n",
"ax = log_partial_hazard.plot()\n",
"# trailing `;` suppresses the Axes repr in the cell output\n",
"ax.set(xlabel=\"age\", ylabel=\"log partial hazard\", title=\"Linear age effect (other covariates at central values)\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment