Last active
June 9, 2021 08:49
-
-
Save jcrudy/10481743 to your computer and use it in GitHub Desktop.
Sampling survival times from arbitrary hazard functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Produce some simulated survival data from a weird hazard function | |
import numpy | |
from samplers import HazardSampler | |
# Fix the RNG seed so the simulation is reproducible, and pick the sample size
numpy.random.seed(1)
m = 1000


# A deliberately irregular hazard: strictly positive and periodic in t
def hazard(t):
    return numpy.exp(numpy.sin(t) - 2.0)


# Draw m failure times, one at a time, from the sampler built on that hazard
sampler = HazardSampler(hazard)
draws = [sampler.draw() for _ in range(m)]
failure_times = numpy.array(draws)

# Non-informative right censoring, purely for demonstration:
# censoring times are uniform on [0, 25) and independent of the failure times
censor_times = numpy.random.uniform(0.0, 25.0, size=m)
# Observed time is whichever of failure / censoring happens first
y = numpy.minimum(failure_times, censor_times)
# c is 1.0 where the failure was observed (it preceded censoring), else 0.0
c = (censor_times > failure_times) * 1.0
# Make some plots of the simulated data | |
from matplotlib import pyplot | |
from statsmodels.distributions import ECDF | |
# Histograms of the raw failure times and of the right-censored observations;
# each one is saved to disk and then displayed
for values, title, filename in (
    (failure_times, 'Uncensored Failure Times', 'uncensored_hist.png'),
    (y, 'Non-informatively Right Censored Failure Times', 'censored_hist.png'),
):
    pyplot.hist(values, bins=50)
    pyplot.title(title)
    pyplot.savefig(filename)
    pyplot.show()

# Compare the true survival function with the empirical one.
# NOTE: the empirical curve is 1 - ECDF of the *uncensored* failure times,
# not of the censored sample y.
t = numpy.arange(0, 20.0, .1)
S = numpy.array([sampler.survival_function(time_point) for time_point in t])
empirical_cdf = ECDF(failure_times)
S_hat = 1.0 - empirical_cdf(t)

pyplot.figure()
pyplot.title('Survival Function Comparison')
pyplot.plot(t, S, 'r', lw=3, label='True survival function')
pyplot.plot(t, S_hat, 'b--', lw=3, label='Sampled survival function (1 - ECDF)')
pyplot.legend()
pyplot.xlabel('Time')
pyplot.ylabel('Proportion Still Alive')
pyplot.savefig('survival_comp.png')
pyplot.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import scipy.integrate | |
class HazardSampler(object):
    """Sample failure times from a user-supplied hazard function.

    The constructor chains together a cumulative hazard, a survival
    function, a CDF and a numerically inverted CDF, and draws are produced
    by feeding uniform random numbers through the inverted CDF.

    Parameters
    ----------
    hazard : callable
        Hazard rate as a function of a scalar time ``t``.
    start : float, optional
        Point at which the numerical CDF inversion starts its search.
    step : float or None, optional
        Initial step size for the inversion search.  When ``None`` it is
        derived from the hazard: ``2 / hazard(0)`` if ``hazard(0) > 0``,
        otherwise ``200`` divided by the integral of the hazard over
        ``[0, 100]``.
    """

    def __init__(self, hazard, start=0.0, step=None):
        self.hazard = hazard
        if step is None:
            h0 = hazard(0.0)
            if h0 > 0:
                step = 2.0 / h0
            else:
                # Reasonable default. Not efficient in some cases.
                # quad() returns an (integral, abs_error) tuple; only the
                # integral is wanted (dividing by the tuple raises TypeError).
                total_hazard = scipy.integrate.quad(hazard, 0.0, 100.0)[0]
                if total_hazard > 0:
                    step = 200.0 / total_hazard
                else:
                    # Hazard integrates to zero on [0, 100]; use a unit step
                    # rather than dividing by zero.
                    step = 1.0
        self.cumulative_hazard = CumulativeHazard(hazard)
        self.survival_function = SurvivalFunction(self.cumulative_hazard)
        self.cdf = Cdf(self.survival_function)
        # Failure times are searched for on [0, inf)
        self.inverse_cdf = InverseCdf(self.cdf, start=start, step=step, lower=0.0)
        self.sampler = InversionTransformSampler(self.inverse_cdf)

    def draw(self):
        """Return one failure time drawn from the hazard's distribution."""
        return self.sampler.draw()
class InversionTransformSampler(object):
    """Draw samples by pushing uniform random numbers through an inverse CDF.

    Parameters
    ----------
    inverse_cdf : callable
        Called with a single uniform draw; its return value is the sample.
    """

    def __init__(self, inverse_cdf):
        self.inverse_cdf = inverse_cdf

    def draw(self):
        """Return ``inverse_cdf(u)`` for one ``u`` from ``numpy.random.uniform(0, 1)``."""
        return self.inverse_cdf(numpy.random.uniform(0, 1))
class CumulativeHazard(object):
    """Callable giving the integral of a hazard function from 0 to ``t``.

    Parameters
    ----------
    hazard : callable
        Hazard rate as a function of a scalar time.
    """

    def __init__(self, hazard):
        self.hazard = hazard

    def __call__(self, t):
        """Return the numerically integrated hazard over ``[0, t]``."""
        # quad() yields (integral, estimated absolute error); keep the integral
        integral, _abs_error = scipy.integrate.quad(self.hazard, 0.0, t)
        return integral
class SurvivalFunction(object):
    """Callable computing ``exp(-H(t))`` from a cumulative hazard ``H``.

    Parameters
    ----------
    cumulative_hazard : callable
        Maps a time ``t`` to the cumulative hazard at ``t``.
    """

    def __init__(self, cumulative_hazard):
        self.cumulative_hazard = cumulative_hazard

    def __call__(self, t):
        """Return ``exp(-cumulative_hazard(t))``."""
        accumulated = self.cumulative_hazard(t)
        return numpy.exp(-accumulated)
class Cdf(object):
    """Callable computing ``1 - S(t)`` from a survival function ``S``.

    Parameters
    ----------
    survival_function : callable
        Maps a time ``t`` to the survival function's value at ``t``.
    """

    def __init__(self, survival_function):
        self.survival_function = survival_function

    def __call__(self, t):
        """Return one minus the survival function evaluated at ``t``."""
        surviving = self.survival_function(t)
        return 1.0 - surviving
class InverseCdf(object):
    """Numerically invert a CDF with a step-halving search.

    Starting from ``start``, the search walks towards the target in
    increments of ``step`` and halves the step each time it overshoots,
    until the CDF value is within ``precision`` of the requested
    probability.

    Parameters
    ----------
    cdf : callable
        The function to invert, called with a scalar.
    start : float
        Point at which every search begins.
    step : float
        Initial step size of the search.
    precision : float, optional
        Tolerance on ``|cdf(x) - p|`` (not on ``x``) at which to stop.
    lower, upper : float, optional
        Bounds the search position is clamped to.
    """

    def __init__(self, cdf, start, step, precision=1e-8, lower=float('-inf'),
                 upper=float('inf')):
        self.cdf = cdf
        self.precision = precision
        self.start = start
        self.step = step
        self.lower = lower
        self.upper = upper

    def __call__(self, p):
        """Return ``x`` such that ``cdf(x)`` is within ``precision`` of ``p``.

        If the search can make no further progress (it is pinned against
        ``lower``/``upper``, or the step has shrunk below floating point
        resolution, e.g. at a jump in the CDF), the current position is
        returned as the best available answer.

        Raises
        ------
        ValueError
            If ``p`` is not a probability in ``[0, 1]`` (including NaN);
            such a target can never be matched and would loop forever.
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError('p must be in the interval [0, 1], got %r' % (p,))
        last_diff = None
        step = self.step
        current = self.start
        while True:
            diff = self.cdf(current) - p
            if abs(diff) < self.precision:
                break
            if diff < 0:
                # Below the target: move right, clamped to the upper bound
                candidate = min(current + step, self.upper)
                overshot = last_diff is not None and last_diff > 0
            else:
                # Above the target: move left, clamped to the lower bound
                candidate = max(current - step, self.lower)
                overshot = last_diff is not None and last_diff < 0
            if candidate == current:
                # No movement is possible any more; iterating again would
                # evaluate the same point forever.
                break
            current = candidate
            if overshot:
                # The previous move crossed the target, so tighten the search
                step *= 0.5
            last_diff = diff
        return current
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from .samplers import InverseCdf, HazardSampler | |
from nose.tools import assert_almost_equal | |
import scipy.stats | |
import numpy | |
from matplotlib import pyplot | |
from statsmodels.distributions import ECDF | |
# Seed the global RNG so the randomized checks below are reproducible
numpy.random.seed(1)


class TestInverseCdf(object):
    # Number of random probabilities checked per test
    test_size = 100

    def test_normal_cdf(self):
        '''
        Invert the standard normal CDF numerically and compare the result
        with scipy's quantile function to 5 decimal places.
        '''
        normal = scipy.stats.norm
        inverter = InverseCdf(normal.cdf, 0.0, 1.0)
        for _ in range(self.test_size):
            probability = numpy.random.uniform()
            expected_quantile = normal.ppf(probability)
            assert_almost_equal(expected_quantile, inverter(probability), 5)
class TestHazardSampler(object):
    # Number of random probabilities used for the pointwise checks
    test_size = 100

    def test_constant_hazard(self):
        '''
        A constant unit hazard corresponds to the standard exponential
        distribution, so every derived function and the sampled values
        are compared against scipy.stats.expon.
        '''
        exponential = scipy.stats.expon
        sampler = HazardSampler(lambda t: 1.0)

        # Pointwise agreement of H(t), S(t), F(t) and F^-1(u)
        for _ in range(self.test_size):
            u = numpy.random.uniform()
            t = exponential.ppf(u)
            assert_almost_equal(sampler.cumulative_hazard(t), t, 5)
            assert_almost_equal(sampler.survival_function(t), 1.0 - u, 5)
            assert_almost_equal(sampler.cdf(t), u, 5)
            assert_almost_equal(sampler.inverse_cdf(u), t, 5)

        # Distributional agreement: the ECDF of a large sample should match
        # the exponential CDF to 2 decimal places on a grid
        sample = numpy.array([sampler.draw() for _ in range(10000)])
        grid = numpy.arange(0.0, 10.0, .1)
        observed = ECDF(sample)(grid)
        expected = exponential.cdf(grid)
        numpy.testing.assert_almost_equal(observed, expected, 2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment