Last active
December 23, 2021 07:54
-
-
Save alvinthai/ac2505970c96b9888abdfb1577c5c9c5 to your computer and use it in GitHub Desktop.
survival_random_forest.r
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Data load -------------------------------------------------------------------
# ACTG 320 data file (from the QSurvival "AIDSdata" vignette folder). The file
# is read without a header row; column names are supplied here.
actg320_colnames <- c(
  "id", "time", "censor", "time_d", "censor_d", "treatment", "treatment_group",
  "strat2", "sex", "raceth", "ivdrug", "hemophil", "karnof", "cd4", "priorzdv",
  "age"
)
actg320 <- read.table(
  "https://raw.githubusercontent.com/WinVector/QSurvival/master/vignettes/AIDSdata/actg320.dat",
  col.names = actg320_colnames
)
dim(actg320)
head(actg320)

# Keep (time, censor) as the outcome. Drop the row identifier `id` and the
# alternative outcome pair (time_d, censor_d); per the original author, that
# outcome is rarer and so less balanced. Remaining column order is unchanged:
# time, censor, then the eleven covariates.
actg320 <- actg320[, setdiff(names(actg320), c("id", "time_d", "censor_d"))]
# Packages --------------------------------------------------------------------
library(ranger)    # install.packages("ranger") if missing
library(survival)  # install.packages("survival") if missing; provides Surv()

# Right-censored outcome Surv(time, censor) modelled on all eleven covariates.
survival_formula <- formula(paste0(
  "Surv(time, censor) ~ treatment + treatment_group + strat2 + sex + raceth",
  " + ivdrug + hemophil + karnof + cd4 + priorzdv + age"
))
survival_formula
# Survival forest on the full data --------------------------------------------
survival_model <- ranger(survival_formula,
                         data = actg320,
                         seed = 1234,
                         importance = "permutation",
                         mtry = 2,
                         verbose = TRUE,
                         num.trees = 50,
                         write.forest = TRUE)

# Permutation variable importance, smallest to largest. A random forest has
# no coefficients; this importance score is its closest analogue.
sort(survival_model$variable.importance)

# Compare the predicted survival curves of two patients (rows 1 and 56), then
# print their covariates to see what separates them.
plot(survival_model$unique.death.times, survival_model$survival[1, ],
     type = "l", col = "orange", ylim = c(0.4, 1),
     xlab = "time", ylab = "predicted survival probability")
lines(survival_model$unique.death.times, survival_model$survival[56, ],
      col = "blue")
actg320[1, ]
actg320[56, ]

# Overlay the curves of (up to) the first 100 patients: row 1 in orange, the
# rest in red. Bounding by the row count avoids indexing past the end of
# smaller data sets; seq_len(...)[-1] is empty when there is only one row.
n_curves <- min(100L, nrow(survival_model$survival))
plot(survival_model$unique.death.times, survival_model$survival[1, ],
     type = "l", col = "orange", ylim = c(0.4, 1),
     xlab = "time", ylab = "predicted survival probability")
for (x in seq_len(n_curves)[-1]) {
  lines(survival_model$unique.death.times, survival_model$survival[x, ],
        col = "red")
}
# Train / validation split ----------------------------------------------------
# Each row goes to training with probability 0.5; the seed makes it repeatable.
set.seed(1234)
random_splits <- runif(nrow(actg320))
train_df_official <- actg320[random_splits < .5, ]
validate_df_official <- actg320[random_splits >= .5, ]
dim(train_df_official)
dim(validate_df_official)

# Prediction horizon, in the same units as `time` (alternative tried: 103).
period_choice <- 82
table(train_df_official$time)

# Classification target: 1 when the event was observed (censor == 1) on or
# before the horizon, otherwise 0. Rows censored before the horizon are also
# coded 0 even though their status at the horizon is unknown.
# (The misspelled name `train_df_classificaiton` is used by later code.)
train_df_classificaiton <- train_df_official
train_df_classificaiton$ReachedEvent <- as.numeric(
  train_df_classificaiton$censor == 1 &
    train_df_classificaiton$time <= period_choice
)
summary(train_df_classificaiton$ReachedEvent)

validate_df_classification <- validate_df_official
validate_df_classification$ReachedEvent <- as.numeric(
  validate_df_classification$censor == 1 &
    validate_df_classification$time <= period_choice
)
summary(validate_df_classification$ReachedEvent)
# Baseline classifier: gradient boosting on the horizon label -----------------
# Predictors are every column except the outcome fields.
feature_names <- setdiff(names(train_df_classificaiton),
                         c("ReachedEvent", "time", "censor"))

library(gbm)  # install.packages("gbm") if missing

# Builds ReachedEvent ~ treatment + treatment_group + ... + age, with the terms
# in the same order as feature_names.
classification_formula <- reformulate(feature_names, response = "ReachedEvent")

set.seed(1234)
gbm_model <- gbm(classification_formula,
                 data = train_df_classificaiton,
                 distribution = "bernoulli",
                 n.trees = 500,
                 interaction.depth = 3,
                 shrinkage = 0.01,
                 bag.fraction = 0.5,
                 keep.data = FALSE,
                 cv.folds = 5)

# Number of boosting iterations selected by gbm.perf (it also draws its plot).
nTrees <- gbm.perf(gbm_model)
validate_predictions <- predict(gbm_model,
                                newdata = validate_df_classification[, feature_names],
                                type = "response",
                                n.trees = nTrees)

library(pROC)  # install.packages("pROC") if missing
roc(response = validate_df_classification$ReachedEvent,
    predictor = validate_predictions)
# Survival forest on the training split, scored at the horizon ----------------
survival_model <- ranger(survival_formula,
                         data = train_df_official,
                         seed = 1234,
                         verbose = TRUE,
                         num.trees = 50,
                         mtry = 2,
                         write.forest = TRUE)
survival_model$unique.death.times

suvival_predictions <- predict(
  survival_model,
  validate_df_official[, c("treatment", "treatment_group", "strat2", "sex",
                           "raceth", "ivdrug", "hemophil", "karnof", "cd4",
                           "priorzdv", "age")]
)

# Predicted survival at the horizon for each validation row. The curve is a
# step function over the unique death times, so its value at period_choice is
# the column of the last death time <= period_choice, or 1 before the first
# death. Matching with `==` selected no column whenever no death fell exactly
# on the horizon, which left an empty predictor and made roc() fail.
stopifnot(!is.unsorted(suvival_predictions$unique.death.times))
horizon_col <- findInterval(period_choice, suvival_predictions$unique.death.times)
horizon_survival <- if (horizon_col > 0) {
  suvival_predictions$survival[, horizon_col]
} else {
  rep(1, nrow(suvival_predictions$survival))
}

# Event probability by the horizon is 1 - survival; evaluate it on its own.
roc(response = validate_df_classification$ReachedEvent,
    predictor = 1 - horizon_survival)

# Blend both together ---------------------------------------------------------
# Plain average of the GBM event probability and the forest's event probability.
roc(predictor = (validate_predictions + (1 - horizon_survival)) / 2,
    response = validate_df_classification$ReachedEvent)
# Second-level splits for cross-fitting ---------------------------------------
# Training half -> train_1 / train_2 (seed reset so the split is repeatable).
set.seed(1234)
random_splits <- runif(nrow(train_df_official))
train_1 <- train_df_official[random_splits < .5, ]
train_2 <- train_df_official[random_splits >= .5, ]
dim(train_1)
dim(train_2)

# Validation half -> test_1 / test_2, using the same seed.
set.seed(1234)
random_splits <- runif(nrow(validate_df_official))
test_1 <- validate_df_official[random_splits < .5, ]
test_2 <- validate_df_official[random_splits >= .5, ]
dim(test_1)
dim(test_2)
# Cross-fitted survival features ----------------------------------------------
# Each training half gets its own survival forest. The forest fitted on one
# half scores the other training half plus one test half, so no training row's
# survival feature comes from a forest fitted on that row.

# Survival probability at `horizon` for every row of a survival matrix whose
# columns correspond to `death_times`. The curve is a step function, so the
# value at `horizon` is the column of the last death time <= horizon, or 1 when
# the horizon precedes every death time. Exact `==` matching found no column
# when a half had no death exactly at the horizon; cbind() then added nothing
# and the following rename silently relabelled `age` as the survival feature.
survival_prob_at <- function(surv, death_times, horizon) {
  stopifnot(!is.unsorted(death_times), ncol(surv) == length(death_times))
  col <- findInterval(horizon, death_times)
  if (col == 0) {
    return(rep(1, nrow(surv)))
  }
  surv[, col]
}

surv_1 <- ranger(survival_formula,
                 data = train_1,
                 verbose = TRUE,
                 seed = 1234,
                 num.trees = 50,
                 mtry = 2,
                 write.forest = TRUE)
surv_1$unique.death.times
# Rows: train_2 first, then test_2.
preds <- predict(surv_1, rbind(train_2[, feature_names], test_2[, feature_names]))
preds_1 <- data.frame(preds$survival)

surv_2 <- ranger(survival_formula,
                 data = train_2,
                 verbose = TRUE,
                 seed = 1234,
                 num.trees = 50,
                 mtry = 2,
                 write.forest = TRUE)
surv_2$unique.death.times
# Rows: train_1 first, then test_1.
preds <- predict(surv_2, rbind(train_1[, feature_names], test_1[, feature_names]))
preds_2 <- data.frame(preds$survival)

probs_1 <- survival_prob_at(preds_1, surv_1$unique.death.times, period_choice)
probs_2 <- survival_prob_at(preds_2, surv_2$unique.death.times, period_choice)
stopifnot(length(probs_1) == nrow(train_2) + nrow(test_2),
          length(probs_2) == nrow(train_1) + nrow(test_1))

# Attach the feature by position, following the rbind() order above.
train_2_ensemble <- train_2
train_2_ensemble$survival_probablities <- probs_1[seq_len(nrow(train_2))]
dim(train_2_ensemble)
names(train_2_ensemble)
test_2_ensemble <- test_2
test_2_ensemble$survival_probablities <-
  probs_1[nrow(train_2) + seq_len(nrow(test_2))]

train_1_ensemble <- train_1
train_1_ensemble$survival_probablities <- probs_2[seq_len(nrow(train_1))]
test_1_ensemble <- test_1
test_1_ensemble$survival_probablities <-
  probs_2[nrow(train_1) + seq_len(nrow(test_1))]
# Stack the cross-fitted halves back together ---------------------------------
train_df_final <- rbind(train_1_ensemble, train_2_ensemble)
validate_df_final <- rbind(test_1_ensemble, test_2_ensemble)

# Same horizon label as the baseline: 1 when the event was observed
# (censor == 1) on or before period_choice, otherwise 0.
train_df_final$ReachedEvent <- as.numeric(
  train_df_final$censor == 1 & train_df_final$time <= period_choice
)
summary(train_df_final$ReachedEvent)
validate_df_final$ReachedEvent <- as.numeric(
  validate_df_final$censor == 1 & validate_df_final$time <= period_choice
)

# Predictors: the eleven covariates plus the cross-fitted survival probability.
feature_names <- setdiff(names(train_df_final),
                         c("ReachedEvent", "time", "censor"))
classification_formula <- formula(paste0(
  "ReachedEvent ~ treatment + treatment_group + strat2 + sex + raceth",
  " + ivdrug + hemophil + karnof + cd4 + priorzdv + age + survival_probablities"
))

# Stacked GBM; interaction.depth = 1 (stumps) here versus 3 in the baseline.
set.seed(1234)
gbm_model <- gbm(classification_formula,
                 data = train_df_final,
                 distribution = "bernoulli",
                 n.trees = 500,
                 interaction.depth = 1,
                 shrinkage = 0.01,
                 bag.fraction = 0.5,
                 keep.data = FALSE,
                 cv.folds = 5)
nTrees <- gbm.perf(gbm_model)
validate_predictions <- predict(gbm_model,
                                newdata = validate_df_final[, feature_names],
                                type = "response",
                                n.trees = nTrees)
roc(response = validate_df_final$ReachedEvent, predictor = validate_predictions)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment