# Gist by @CamDavidsonPilon, created October 4, 2018.
import numpy as np


def concordance_index(event_times, predicted_scores, event_observed=None):
    """
    Calculates the concordance index (C-index) between two series
    of event times. The first is the actual survival times from
    the experimental data, and the other is the predicted survival
    times from a model of some kind.

    The concordance index is a value between 0 and 1 where:
    0.5 is the expected result from random predictions,
    1.0 is perfect concordance, and
    0.0 is perfect anti-concordance (multiply predictions by -1 to get 1.0).

    The score is usually 0.6-0.7 for survival models.

    See:
    Harrell FE, Lee KL, Mark DB. Multivariable prognostic models: issues in
    developing models, evaluating assumptions and adequacy, and measuring and
    reducing errors. Statistics in Medicine 1996;15(4):361-87.

    Parameters:
      event_times: a (n,) array of observed survival times.
      predicted_scores: a (n,) array of predicted scores - these could be
        survival times, or hazards, etc. See
        https://stats.stackexchange.com/questions/352183/use-median-survival-time-to-calculate-cph-c-statistic/352435#352435
      event_observed: a (n,) array of censorship flags, 1 if observed,
        0 if not. Default None assumes all observed.

    Returns:
      c-index: a value between 0 and 1.
    """
    event_times = np.array(event_times, dtype=float)
    predicted_scores = np.array(predicted_scores, dtype=float)

    # Allow for (n, 1) or (1, n) arrays
    if event_times.ndim == 2 and (event_times.shape[0] == 1 or
                                  event_times.shape[1] == 1):
        # Flatten array
        event_times = event_times.ravel()
    # Allow for (n, 1) or (1, n) arrays
    if (predicted_scores.ndim == 2 and
            (predicted_scores.shape[0] == 1 or
             predicted_scores.shape[1] == 1)):
        # Flatten array
        predicted_scores = predicted_scores.ravel()

    if event_times.shape != predicted_scores.shape:
        raise ValueError("Event times and predictions must have the same shape")
    if event_times.ndim != 1:
        raise ValueError("Event times can only be 1-dimensional: (n,)")

    if event_observed is None:
        event_observed = np.ones(event_times.shape[0], dtype=float)
    else:
        # Convert before checking the shape, so that list-like input works too
        event_observed = np.array(event_observed, dtype=float).ravel()
        if event_observed.shape != event_times.shape:
            raise ValueError("Observed events must be 1-dimensional of same length as event times")

    return _concordance_index(event_times,
                              predicted_scores,
                              event_observed)
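

# A minimal usage sketch (added; not part of the original gist). With true
# times [1, 2, 2, 3] and predictions [1, 2, 1, 4], all observed, the pair of
# tied true times is not comparable, leaving 5 admissible pairs: 4 concordant
# and 1 with tied predictions (counted as 1/2), so C = 4.5 / 5 = 0.9. Censoring
# the third subject drops its (concordant) pair with the later death but makes
# its pair with the tied death admissible and discordant, so C = 3.5 / 5 = 0.7.
def _example_concordance_index():
    times = [1, 2, 2, 3]
    preds = [1, 2, 1, 4]
    assert concordance_index(times, preds) == 0.9
    assert concordance_index(times, preds, [1, 1, 0, 1]) == 0.7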


class _BTree(object):
    """A simple balanced binary order statistic tree to help compute the concordance.

    When computing the concordance, we know all the values the tree will ever contain. That
    condition simplifies this tree a lot. It means that instead of crazy AVL/red-black shenanigans
    we can simply do the following:

    - Store the final tree in flattened form in an array (so node i's children are 2i+1, 2i+2)
    - Additionally, store the current size of each subtree in another array with the same indices
    - To insert a value, just find its index, increment the size of the subtree at that index and
      propagate
    - To get the rank of an element, you add up a bunch of subtree counts
    """

    def __init__(self, values):
        """
        Parameters:
          values: List of sorted (ascending), unique values that will be inserted.
        """
        self._tree = self._treeify(values)
        self._counts = np.zeros_like(self._tree, dtype=int)

    @staticmethod
    def _treeify(values):
        """Convert the np.ndarray `values` into a complete balanced tree.

        Assumes `values` is sorted ascending. Returns an array `t` of the same length in which
        t[i] > t[2i+1] and t[i] < t[2i+2] for all i."""
        if len(values) == 1:  # this case causes problems later
            return values
        tree = np.empty_like(values)
        # Tree indices work as follows:
        # 0 is the root
        # 2n+1 is the left child of n
        # 2n+2 is the right child of n
        # So we now rearrange `values` into that format...

        # The first step is to remove the bottom row of leaves, which might not be exactly full
        last_full_row = int(np.log2(len(values) + 1) - 1)
        len_ragged_row = len(values) - (2 ** (last_full_row + 1) - 1)
        if len_ragged_row > 0:
            bottom_row_ix = np.s_[:2 * len_ragged_row:2]
            tree[-len_ragged_row:] = values[bottom_row_ix]
            values = np.delete(values, bottom_row_ix)

        # Now `values` is length 2**n - 1, so can be packed efficiently into a tree
        # Last row of nodes is indices 0, 2, ..., 2**n - 2
        # Second-last row is indices 1, 5, ..., 2**n - 3
        # nth-last row is indices (2**n - 1)::(2**(n+1))
        values_start = 0
        values_space = 2
        values_len = 2 ** last_full_row
        while values_start < len(values):
            tree[values_len - 1:2 * values_len - 1] = values[values_start::values_space]
            values_start += values_space // 2
            values_space *= 2
            values_len //= 2
        return tree
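
    # Added worked trace of _treeify for values = [1, 2, 3, 4, 5]: the last full
    # row has 2**1 = 2 nodes, so the ragged bottom row holds 5 - 3 = 2 leaves,
    # values[0] and values[2] (i.e. 1 and 3), stored at tree[-2:]. The remaining
    # [2, 4, 5] then packs as tree[1:3] = [2, 5] and tree[0] = 4, giving
    # [4, 2, 5, 1, 3].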

    def insert(self, value):
        """Insert an occurrence of `value` into the btree."""
        i = 0
        n = len(self._tree)
        while i < n:
            cur = self._tree[i]
            self._counts[i] += 1
            if value < cur:
                i = 2 * i + 1
            elif value > cur:
                i = 2 * i + 2
            else:
                return
        raise ValueError("Value %s not contained in tree. "
                         "Also, the counts are now messed up." % value)

    def __len__(self):
        return self._counts[0]

    def rank(self, value):
        """Returns the rank and count of the value in the btree."""
        i = 0
        n = len(self._tree)
        rank = 0
        count = 0
        while i < n:
            cur = self._tree[i]
            if value < cur:
                i = 2 * i + 1
                continue
            elif value > cur:
                rank += self._counts[i]
                # subtract off the right tree if it exists
                nexti = 2 * i + 2
                if nexti < n:
                    rank -= self._counts[nexti]
                    i = nexti
                    continue
                else:
                    return (rank, count)
            else:  # value == cur
                count = self._counts[i]
                lefti = 2 * i + 1
                if lefti < n:
                    nleft = self._counts[lefti]
                    count -= nleft
                    rank += nleft
                    righti = lefti + 1
                    if righti < n:
                        count -= self._counts[righti]
                return (rank, count)
        return (rank, count)
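

# Added sketch (not in the original gist): using _BTree in isolation. The tree
# is built over every unique value it will ever hold; insert() then records
# occurrences, and rank(x) returns (#inserted values < x, #inserted values == x).
def _example_btree():
    tree = _BTree(np.array([1.0, 2.0, 4.0]))
    for v in [1.0, 2.0, 2.0, 4.0]:
        tree.insert(v)
    assert len(tree) == 4
    assert tree.rank(2.0) == (1, 2)  # one value below 2.0, two equal to it
    assert tree.rank(4.0) == (3, 1)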


def _concordance_index(event_times, predicted_event_times, event_observed):
    """Find the concordance index in n * log(n) time.

    Assumes the data has been verified by lifelines.utils.concordance_index first.
    """
    # Here's how this works.
    #
    # It would be pretty easy to do if we had no censored data and no ties. There, the basic idea
    # would be to iterate over the cases in order of their true event time (from least to greatest),
    # while keeping track of a pool of *predicted* event times for all cases previously seen (= all
    # cases that we know should be ranked lower than the case we're looking at currently).
    #
    # If the pool has O(log n) INSERT and O(log n) RANK (i.e., "how many things in the pool have
    # value less than x"), then the following algorithm is n log n:
    #
    #     Sort the times and predictions by time, increasing
    #     n_pairs, n_correct := 0
    #     pool := {}
    #     for each prediction p:
    #         n_pairs += len(pool)
    #         n_correct += rank(pool, p)
    #         add p to pool
    #
    # There are three complications: tied ground truth values, tied predictions, and censored
    # observations.
    #
    # - To handle tied true event times, we modify the inner loop to work in *batches* of
    #   observations p_1, ..., p_n whose true event times are tied, and then add them all to the
    #   pool simultaneously at the end.
    #
    # - To handle tied predictions, which should each count for 0.5, we switch to
    #       n_correct += min_rank(pool, p)
    #       n_tied += count(pool, p)
    #
    # - To handle censored observations, we handle each batch of tied, censored observations just
    #   after the batch of observations that died at the same time (since those censored
    #   observations are comparable to all the observations that died at the same time or
    #   previously). However, we do NOT add them to the pool at the end, because they are NOT
    #   comparable with any observations that leave the study afterward, whether or not those
    #   observations get censored.
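    #
    # Added illustration: if subject A dies at t=2 and subject B is censored at
    # t=2, the pair (A, B) is admissible (B survived at least until A's death),
    # but B is never paired with anyone who exits after t=2.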
    died_mask = event_observed.astype(bool)
    # TODO: is event_times already sorted? That would be nice...
    died_truth = event_times[died_mask]
    ix = np.argsort(died_truth)
    died_truth = died_truth[ix]
    died_pred = predicted_event_times[died_mask][ix]

    censored_truth = event_times[~died_mask]
    ix = np.argsort(censored_truth)
    censored_truth = censored_truth[ix]
    censored_pred = predicted_event_times[~died_mask][ix]

    censored_ix = 0
    died_ix = 0
    times_to_compare = _BTree(np.unique(died_pred))
    num_pairs = np.int64(0)
    num_correct = np.int64(0)
    num_tied = np.int64(0)

    def handle_pairs(truth, pred, first_ix):
        """
        Handle all pairs that exited at the same time as truth[first_ix].

        Returns:
          (pairs, correct, tied, next_ix)
          pairs: the number of new comparisons performed
          correct: the number of comparisons correctly predicted
          tied: the number of comparisons with tied predictions
          next_ix: the next index that needs to be handled
        """
        next_ix = first_ix
        while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:
            next_ix += 1
        pairs = len(times_to_compare) * (next_ix - first_ix)
        correct = np.int64(0)
        tied = np.int64(0)
        for i in range(first_ix, next_ix):
            rank, count = times_to_compare.rank(pred[i])
            correct += rank
            tied += count
        return (pairs, correct, tied, next_ix)

    # we iterate through cases sorted by exit time:
    # - First, all cases that died at time t0. We add these to the pool of died times.
    # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT
    #   comparable to subsequent elements.
    while True:
        has_more_censored = censored_ix < len(censored_truth)
        has_more_died = died_ix < len(died_truth)
        # Should we look at some censored indices next, or died indices?
        if has_more_censored and (not has_more_died or
                                  died_truth[died_ix] > censored_truth[censored_ix]):
            pairs, correct, tied, next_ix = handle_pairs(censored_truth, censored_pred, censored_ix)
            censored_ix = next_ix
        elif has_more_died and (not has_more_censored or
                                died_truth[died_ix] <= censored_truth[censored_ix]):
            pairs, correct, tied, next_ix = handle_pairs(died_truth, died_pred, died_ix)
            for pred in died_pred[died_ix:next_ix]:
                times_to_compare.insert(pred)
            died_ix = next_ix
        else:
            assert not (has_more_died or has_more_censored)
            break
        num_pairs += pairs
        num_correct += correct
        num_tied += tied

    if num_pairs == 0:
        raise ZeroDivisionError("No admissible pairs in the dataset.")

    return (num_correct + num_tied / 2) / num_pairs
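

# Added cross-check (not in the original gist): a naive O(n^2) reference
# implementation of the same definition. A pair (i, j) is admissible when the
# earlier exit belongs to an observed death; ties in exit time are admissible
# only when exactly one of the two is a death. Tied predictions count for 1/2.
# Handy for auditing the O(n log n) version above on small random inputs.
def _naive_concordance_index(event_times, predicted_scores, event_observed):
    num_pairs = 0
    num_correct = 0.0
    n = len(event_times)
    for i in range(n):
        if not event_observed[i]:
            continue  # the earlier member of an admissible pair must be a death
        for j in range(n):
            if event_times[i] > event_times[j]:
                continue  # i must exit no later than j
            if event_times[i] == event_times[j] and event_observed[j]:
                continue  # ties in time only count against a censored j (also skips i == j)
            num_pairs += 1
            if predicted_scores[i] < predicted_scores[j]:
                num_correct += 1.0
            elif predicted_scores[i] == predicted_scores[j]:
                num_correct += 0.5
    if num_pairs == 0:
        raise ZeroDivisionError("No admissible pairs in the dataset.")
    return num_correct / num_pairs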