Skip to content

Instantly share code, notes, and snippets.

@hminle
Last active January 5, 2024 17:31
Show Gist options
  • Save hminle/bc1a3dea64e42f8dc90c2cd617f71f6f to your computer and use it in GitHub Desktop.
Save hminle/bc1a3dea64e42f8dc90c2cd617f71f6f to your computer and use it in GitHub Desktop.
DeSurv
import torch
import numpy as np
import torch.nn as nn
from nfg.nfg_torch import *
from nfg.nfg_api import NeuralFineGray
class DeSurv(NeuralFineGray):
    """DeSurv estimator plugged into the ``NeuralFineGray`` API.

    Overrides the torch-model factory so that a ``DeSurvTorch`` network is
    built and ``total_loss`` is registered as ``self.loss``, and overrides
    ``predict_survival`` to consume that network's outputs.
    """

    def _gen_torch_model(self, inputdim, optimizer, risks):
        """Build the ``DeSurvTorch`` network and register its loss.

        Args:
            inputdim: Input dimension forwarded to ``DeSurvTorch``.
            optimizer: Optimizer identifier stored on the torch model.
            risks: Number of competing risks.

        Returns:
            A ``DeSurvTorch`` cast to double precision, moved to CUDA when
            ``self.cuda > 0``.
        """
        # NOTE(review): this previously read ``losses.total_loss``, but no
        # ``losses`` name is imported in this file. The loss written for
        # ``DeSurvTorch`` lives at module level here, so reference it directly.
        self.loss = total_loss
        model = DeSurvTorch(inputdim, **self.params,
                            risks=risks,
                            optimizer=optimizer).double()
        if self.cuda > 0:
            model = model.cuda()
        return model

    def predict_survival(self, x, t, risk=None):
        """Predict survival probabilities at one or several time horizons.

        Args:
            x: Covariates; passed through ``self._preprocess_test_data``.
            t: A single time, or a list / tuple / NumPy array of times.
            risk: 1-based index of the risk whose predicted incidence is
                subtracted from one. ``None`` subtracts the sum over all
                risks instead.

        Returns:
            NumPy array with ``len(x)`` rows and one column per time in ``t``.

        Raises:
            Exception: If the model has not been fitted.
        """
        # Fail fast before doing any preprocessing work.
        if not self.fitted:
            raise Exception("The model has not been fitted yet. Please fit the " +
                            "model using the `fit` method on some training data " +
                            "before calling `predict_survival`.")

        x = self._preprocess_test_data(x)

        # Normalise ``t`` to a plain list of horizons. Tuples and arrays used
        # to be wrapped as a single element, producing a malformed time tensor.
        if isinstance(t, np.ndarray):
            t = np.atleast_1d(t).tolist()
        elif isinstance(t, tuple):
            t = list(t)
        elif not isinstance(t, list):
            t = [t]

        scores = []
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for t_ in t:
                # Same horizon repeated for every sample in the batch.
                horizon = torch.DoubleTensor([t_] * len(x)).to(x.device)
                pred, _, _ = self.torch_model(x, horizon)
                if risk is None:
                    incidence = pred.sum(1)
                else:
                    incidence = pred[:, int(risk) - 1]
                scores.append(1 - incidence.unsqueeze(1).detach().cpu().numpy())
        return np.concatenate(scores, axis=1)
class CondODENet(nn.Module):
    """
    Network ``f(time, covariates)`` integrated from 0 to a horizon with
    Gauss-Legendre quadrature, then squashed with ``tanh``.

    Code extracted from https://github.com/djdanks/DeSurv
    """

    def __init__(self, cov_dim, layers, output_dim,
                 act = "ReLU", n = 15):
        super().__init__()
        self.output_dim = output_dim

        # Integrand network: one extra input column for time; the output
        # width is appended to the hidden sizes and ``nn.Softplus()`` is
        # forwarded as ``last`` to ``create_representation``.
        widths = layers + [output_dim]
        integrand_layers = create_representation(cov_dim + 1, widths, act, last=nn.Softplus())
        self.f = nn.Sequential(*integrand_layers)

        self.n = n

        # Fixed (non-trainable) Gauss-Legendre nodes and weights, each 1 x n.
        nodes, weights = np.polynomial.legendre.leggauss(n)
        self.u_n = nn.Parameter(torch.tensor(nodes, dtype=torch.float32)[None, :],
                                requires_grad=False)
        self.w_n = nn.Parameter(torch.tensor(weights, dtype=torch.float32)[None, :],
                                requires_grad=False)

    def forward(self, x, horizon):
        """Return ``tanh`` of the quadrature estimate of the integral of ``f``.

        ``x`` supplies one covariate row per sample and ``horizon`` one upper
        integration limit per sample; the result has one row per sample and
        ``self.output_dim`` columns.
        """
        half_horizon = horizon.unsqueeze(-1) / 2.

        # Shift nodes from [-1, 1] to [0, 2], then scale to [0, horizon]: N x n.
        eval_times = torch.matmul(half_horizon, 1 + self.u_n)

        # Stack every (time, covariates) pair: each sample's n nodes sit in
        # consecutive rows, matching the row order of repeat_interleave.
        time_column = eval_times.flatten().unsqueeze(-1)        # Nn x 1
        tiled_x = x.repeat_interleave(self.n, dim=0)            # Nn x d
        net_input = torch.cat((time_column, tiled_x), 1)        # Nn x (d+1)

        integrand = self.f(net_input).reshape((len(x), self.n, self.output_dim))  # N x n x d_out
        weighted_sum = (self.w_n[:, :, None] * integrand).sum(dim=1)
        integral = half_horizon * weighted_sum
        return torch.tanh(integral)
class DeSurvTorch(nn.Module):
    """Torch module combining a softmax risk-balance head with a ``CondODENet``."""

    def __init__(self, inputdim, layers = [100, 100, 100], act = 'ReLU', layers_surv = [100],
                 risks = 1, optimizer = "Adam", n = 15):
        super().__init__()
        self.input_dim = inputdim
        self.risks = risks  # Number of competing risks
        self.optimizer = optimizer

        # Balance head: one output per risk, with ``nn.Softmax(dim=1)``
        # forwarded as ``last`` to ``create_representation``.
        balance_layers = create_representation(inputdim, layers + [risks], act,
                                               last=nn.Softmax(dim=1))
        self.balance = nn.Sequential(*balance_layers)

        # One integrated output per risk.
        self.odenet = CondODENet(inputdim, layers_surv, risks, act, n=n)

    def forward(self, x, horizon):
        """Return ``(balance * Fr, balance, Fr)`` for inputs ``x`` at ``horizon``.

        ``balance`` is the output of the balance head on ``x`` and ``Fr`` is
        the output of ``self.odenet`` on ``(x, horizon)``.
        """
        risk_weights = self.balance(x)
        integrated = self.odenet(x, horizon)
        weighted = risk_weights * integrated
        return weighted, risk_weights, integrated
def total_loss(model, x, t, e, eps = 1e-10):
    """Negative log-likelihood for competing risks, divided by ``len(x)``.

    Args:
        model: Object exposing ``forward(x, t)`` returning
            ``(pred, balance, ode)``, an integer ``risks`` attribute and an
            ``odenet.f`` callable.
        x: Covariates, one row per sample.
        t: Time per sample.
        e: Event indicator per sample; ``0`` selects the censored term and
            ``k + 1`` selects the term for risk column ``k``.
        eps: Constant added inside every logarithm for numerical stability.

    Returns:
        The summed negative log-likelihood divided by ``len(x)``.
    """
    incidence, balance, ode = model.forward(x, t)

    # Samples with e == 0: log of one minus the summed predicted incidence.
    no_event = e == 0
    survival = 1 - incidence[no_event].sum(dim=1)
    loss = -torch.log(survival + eps).sum()

    # Samples with an event: one log-density term per risk.
    for risk in range(1, model.risks + 1):
        col = risk - 1
        observed = e == risk

        # Evaluate the integrand network at the event time itself.
        time_and_x = torch.cat((t[observed].unsqueeze(1), x[observed]), 1)
        derivative = model.odenet.f(time_and_x)

        # 1 - ode**2 is the tanh derivative when ``ode`` is a tanh output,
        # as CondODENet.forward in this file returns.
        log_squash_grad = torch.log(1 - ode[observed][:, col] ** 2 + eps)
        log_derivative = torch.log(derivative[:, col] + eps)
        log_balance = torch.log(balance[observed][:, col] + eps)

        loss = loss - (log_squash_grad + log_derivative + log_balance).sum()

    return loss / len(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment