Skip to content

Instantly share code, notes, and snippets.

@cjbayesian
Created January 8, 2020 15:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cjbayesian/1fd16e4c46798c7e0a32965b7e5216cf to your computer and use it in GitHub Desktop.
Save cjbayesian/1fd16e4c46798c7e0a32965b7e5216cf to your computer and use it in GitHub Desktop.
Plot Kaplan-Meier-style survival curves with error bars
import scipy as sp
def beta_errors(num, denom):
    """Return a 95% interval for the proportion ``num`` out of ``denom``.

    The interval is the equal-tailed 95% interval of a
    Beta(num + 1, denom - num + 1) distribution.

    Args:
        num: Number of "successes" (here: subjects still surviving).
        denom: Total number of subjects; expected to satisfy ``denom >= num``.

    Returns:
        A ``(lower, upper)`` pair with the interval end-points.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not make
    # ``scipy.stats`` available as an attribute on every SciPy release.
    from scipy.stats import beta

    # The level is passed positionally because the keyword name of this
    # argument has changed between SciPy releases.
    return beta.interval(0.95, num + 1, denom - num + 1)
def plot_km(df, threshold=0.5, max_days=365, y_text_shrink=1, ax=None):
    """Plot survival curves with 95% bands for a high- and a low-risk group.

    Rows of ``df`` are split on ``df['Pred'] > threshold`` ("High risk")
    versus the remaining rows ("Low risk"). For every day in
    ``range(max_days)`` the plotted value is the fraction of the group whose
    ``survival_time_days`` is strictly greater than that day. Dashed grey
    lines show the interval returned by ``beta_errors`` for the matching
    survivor count and group size. The number of survivors at selected days
    is written in a row below the x-axis for each group.

    NOTE(review): censoring is not modelled here; each curve is the raw
    fraction of the group with a survival time beyond the given day.

    Args:
        df: DataFrame with the columns ``'Pred'`` and
            ``'survival_time_days'``.
        threshold: Rows with ``Pred`` strictly above this value form the
            high-risk group.
        max_days: Number of days to plot (days ``0 .. max_days - 1``).
        y_text_shrink: Multiplier applied to the vertical offsets (-0.15 and
            -0.2, in data coordinates) of the two survivor-count rows.
        ax: Axes to draw on. When ``None`` a new figure is created through
            ``plt.subplots``; this requires ``plt`` (presumably
            ``matplotlib.pyplot``) to be available in the enclosing module.

    Raises:
        ValueError: If ``max_days`` is smaller than 1.
    """
    if max_days < 1:
        raise ValueError("max_days must be at least 1")

    days = range(max_days)
    last_day = max_days - 1
    is_high_risk = df['Pred'] > threshold
    survival_series = df['survival_time_days']
    total = df.shape[0]

    # (legend label, count-row label, count-row y position, survival times)
    groups = [
        ('High risk', 'High Risk n Survived', -0.15 * y_text_shrink,
         survival_series[is_high_risk]),
        ('Low risk', 'Low Risk n Survived', -0.2 * y_text_shrink,
         survival_series[~is_high_risk]),
    ]

    # Only keep ticks that fall inside the plotted range: each tick is used
    # as an index into the per-day survivor counts below.
    tick_candidates = (0, 45, 90, 135, 180)
    xticks = [tick for tick in tick_candidates if tick < max_days]
    # Row labels sit a quarter of the tick spacing to the left of day 0.
    label_x = -tick_candidates[1] / 4

    if ax is None:
        _, ax = plt.subplots(1, 1)

    for legend_label, row_label, y_text, grp in groups:
        group_size = grp.shape[0]
        survival_mean = []
        n_survived = []
        for day in days:
            alive = grp > day
            survival_mean.append(alive.mean())
            n_survived.append(int(alive.sum()))

        intervals = [beta_errors(n, group_size) for n in n_survived]
        lower = [interval[0] for interval in intervals]
        upper = [interval[1] for interval in intervals]

        # Guard the share computation so an empty frame does not divide by 0.
        share = float(group_size) / total if total else 0.0
        proportions = " (n={}, {:.1%})".format(group_size, share)

        ax.plot(days, survival_mean, '-', label=legend_label + proportions)
        ax.plot(days, lower, '--', color='grey')
        ax.plot(days, upper, '--', color='grey')

        ax.text(label_x, y_text, row_label, horizontalalignment='right')
        for tick in xticks:
            txt = " {}".format(n_survived[tick])
            ax.text(tick, y_text, txt, horizontalalignment='center')

    ax.set_xticks(xticks)
    ax.legend(loc=0)
    ax.set_xlim(0, last_day)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Time (days)')
    ax.set_ylabel('Survival Probability')
# Example usage: plot survival curves using the most recent prediction per
# patient. NOTE(review): `preds` is defined elsewhere; it is presumably a
# DataFrame with PAT_ID, APPT_TIME, Pred and survival_time_days columns.
thresh = 0.5
fig, axx = plt.subplots(figsize=(7, 7))

# Sort by patient and appointment time, then keep only the last row per patient.
ttmp = (
    preds
    .sort_values(['PAT_ID', 'APPT_TIME'])
    .drop_duplicates('PAT_ID', keep='last')
)

plot_km(ttmp, threshold=thresh, max_days=181, ax=axx)
axx.grid(True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment