Skip to content

Instantly share code, notes, and snippets.

@fredrick
Created February 22, 2017 18:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fredrick/b3fb22b92aeeafb8df4da9b8cb66ca0f to your computer and use it in GitHub Desktop.
Save fredrick/b3fb22b92aeeafb8df4da9b8cb66ca0f to your computer and use it in GitHub Desktop.
%matplotlib inline
import numpy as np
from lifelines import KaplanMeierFitter
from matplotlib import pyplot as plt
from pylab import rcParams
# Notebook-wide plotting defaults: a wide 20 x 10 inch canvas for the
# survival curves, drawn with matplotlib's built-in 'ggplot' style sheet.
rcParams['figure.figsize'] = (20, 10)
plt.style.use('ggplot')
def _median_survival(kmf):
    """Return the median survival time of a fitted lifelines fitter.

    Newer lifelines releases expose ``median_survival_time_`` (``median_`` is
    deprecated there); older releases only provide ``median_``.
    """
    try:
        return kmf.median_survival_time_
    except AttributeError:
        return kmf.median_


def _fit_report_plot(kmf, data, ax, time_column, observation_column, label=None):
    """Fit *kmf* on *data*, print its head/median summary, and plot it on *ax*."""
    # Only forward ``label`` when given, so the unlabelled case keeps the
    # fitter's own default label exactly as the original call did.
    fit_kwargs = {} if label is None else {'label': label}
    kmf.fit(data[time_column], data[observation_column], **fit_kwargs)
    print(kmf.survival_function_.head())
    print('Median')
    print(_median_survival(kmf))
    kmf.plot(ax=ax)


def run_survival(data, group_by=None, groups=None, *, time_column='time',
                 observation_column='death', ax=None):
    """Fit and plot Kaplan-Meier survival curves, optionally split by a column.

    For every curve, the head of the survival function and the median survival
    time are printed to stdout.

    Args:
        data: DataFrame holding the duration and event columns (and
            ``group_by`` when given).
        group_by: Column name to stratify by. When ``None``, a single curve
            is fitted on all of ``data``.
        groups: Values of ``group_by`` to plot, in order. When ``None`` or
            empty, all non-missing values of the column are used, sorted.
        time_column: Column with the durations passed to ``fit``.
        observation_column: Column with the event indicator passed to ``fit``
            (lifelines' ``event_observed``).
        ax: Axes to draw on. When ``None``, ``plt.subplot(111)`` is used, as
            before.

    Returns:
        The matplotlib axes the curves were drawn on.

    Raises:
        KeyError: If an entry of ``groups`` does not occur in ``group_by``.
    """
    if ax is None:
        ax = plt.subplot(111)
    kmf = KaplanMeierFitter()

    if group_by is None:
        _fit_report_plot(kmf, data, ax, time_column, observation_column)
        return ax

    # Whole-population curve first, for reference against the strata.
    _fit_report_plot(kmf, data, ax, time_column, observation_column,
                     label='baseline')

    # Group by the bare column name: a length-1 list grouper would require
    # tuple keys for get_group() in recent pandas.
    grouped_data = data.groupby(group_by)
    ax.set_title(group_by)

    # len() rather than truthiness so numpy arrays are accepted for `groups`.
    if groups is None or len(groups) == 0:
        # groupby() drops NaN keys, so NaN must not be requested below.
        groups = np.sort(data[group_by].dropna().unique())

    for group in groups:
        subset = grouped_data.get_group(group)
        print(group, len(subset))
        _fit_report_plot(kmf, subset, ax, time_column, observation_column,
                         label=group)
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment