Skip to content

Instantly share code, notes, and snippets.

@dgmp88
Created December 20, 2021 11:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dgmp88/7b37e3c7b7832abb4b162b3e768a4fa6 to your computer and use it in GitHub Desktop.
Save dgmp88/7b37e3c7b7832abb4b162b3e768a4fa6 to your computer and use it in GitHub Desktop.
import torch
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi
from torch import nn
# Device that setup_models() moves the torch model to.
DEVICE = torch.device("cpu")

# Placeholders for the two models; setup_models() replaces both.
MODEL = COX_MODEL = None

# Example survival dataset bundled with lifelines, loaded once at import.
# setup_models() fits the Cox model on it using its "week" (duration) and
# "arrest" (event) columns, and get_risk() predicts on it.
ROSSI = load_rossi()
def setup_models():
    """Create both models and publish them through the module globals.

    Sets ``MODEL`` to a single-layer convolutional network in eval mode on
    ``DEVICE``, and ``COX_MODEL`` to a ``CoxPHFitter`` fitted on ``ROSSI``
    with ``week`` as the duration column and ``arrest`` as the event column.
    """
    global MODEL, COX_MODEL

    # Conv2d(in_channels=3, out_channels=5, kernel_size=5). nn.Module.eval()
    # and nn.Module.to() both modify the module in place and return it, so
    # MODEL refers to the same object at every step.
    MODEL = nn.Sequential(nn.Conv2d(3, 5, 5))
    MODEL.eval()
    MODEL.to(DEVICE)

    # Cox proportional-hazards model fitted on the Rossi data.
    COX_MODEL = CoxPHFitter()
    COX_MODEL.fit(ROSSI, event_col="arrest", duration_col="week")


# Both models are built as soon as the module is imported.
setup_models()
def process_image():
    """Run ``MODEL`` on one random image-shaped tensor and return the output.

    The input is a ``(1, 3, 512, 512)`` tensor of uniform values in
    ``[0, 1)``. It is allocated on ``DEVICE``, the same device that
    ``setup_models()`` moves ``MODEL`` to.

    Returns:
        The output tensor produced by ``MODEL``.
    """
    # Allocate the input on DEVICE: an input left on the default (CPU) device
    # would raise a device-mismatch error once DEVICE is switched to a GPU.
    image = torch.rand(1, 3, 512, 512, device=DEVICE)
    embed = MODEL(image)
    return embed
def get_risk():
    """Return one subject's predicted cumulative risk of arrest at week 10, in percent.

    The fitted ``COX_MODEL`` predicts the survival function S(t) at t = 10
    for every row of ``ROSSI``. The returned value is ``100 * (1 - S(10))``
    for the prediction column labelled ``0`` (the ``ROSSI`` row with index
    label 0).
    """
    # NOTE(review): predictions are computed for all ROSSI rows although only
    # one is used; left unchanged because this file appears to be a crash
    # reproducer and the amount of work may matter.
    survival = COX_MODEL.predict_survival_function(ROSSI, times=[10])
    cumulative_risk = 1 - survival
    first_subject = cumulative_risk[0]
    return 100 * first_subject.values[0]
def this_crashes():
    """Run ``get_risk()`` and then ``process_image()``, discarding both results.

    NOTE(review): the name indicates this sequence is the one that crashes.
    Presumably the lifelines prediction has to run before the torch forward
    pass to trigger it, so keep this order when editing — confirm.
    """
    for step in (get_risk, process_image):
        step()


if __name__ == "__main__":
    # Running the file as a script performs the reproduction.
    this_crashes()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment