Created January 10, 2024 11:48
Reprex: XGBoost survival models (objective `survival:cox`) return `Inf` predictions for some observations
# Install mlr3proba from GitHub only if it is not already available.
if (!requireNamespace("mlr3proba")) {
  remotes::install_github("mlr-org/mlr3proba")
}
#> Loading required namespace: mlr3proba

# Load the grace data set, drop the `id` column, and name the first two
# remaining columns `time` and `status` (the survival target used below).
xdf <- mlr3proba::grace
xdf[["id"]] <- NULL
colnames(xdf)[1:2] <- c("time", "status")

# 1-based row indices into `xdf` that trigger the Inf predictions below.
# NOTE(review): presumably the training split of an mlr3 resampling that
# exposed the problem; the exact provenance is not shown here.
rows = c(1L, 4L, 6L, 9L, 10L, 12L, 16L, 17L, 18L, 19L, 20L, 22L, 25L,
         26L, 27L, 28L, 30L, 33L, 36L, 37L, 38L, 39L, 40L, 43L, 45L, 49L,
         50L, 51L, 52L, 53L, 54L, 55L, 58L, 60L, 61L, 62L, 63L, 64L, 67L,
         68L, 71L, 73L, 74L, 75L, 76L, 77L, 79L, 80L, 81L, 83L, 85L, 86L,
         88L, 92L, 93L, 94L, 95L, 97L, 98L, 99L, 101L, 103L, 105L, 108L,
         109L, 110L, 111L, 112L, 114L, 115L, 117L, 118L, 119L, 121L, 122L,
         125L, 126L, 130L, 132L, 133L, 134L, 135L, 138L, 139L, 141L, 144L,
         145L, 148L, 149L, 150L, 152L, 154L, 157L, 159L, 160L, 162L, 163L,
         166L, 167L, 168L, 169L, 170L, 171L, 173L, 174L, 176L, 177L, 178L,
         180L, 181L, 182L, 183L, 185L, 187L, 188L, 190L, 191L, 192L, 193L,
         194L, 195L, 196L, 197L, 201L, 202L, 203L, 204L, 205L, 206L, 208L,
         209L, 210L, 211L, 213L, 214L, 215L, 216L, 218L, 221L, 224L, 225L,
         226L, 227L, 229L, 230L, 231L, 236L, 237L, 239L, 240L, 241L, 246L,
         247L, 249L, 250L, 251L, 253L, 254L, 255L, 256L, 258L, 259L, 260L,
         261L, 262L, 263L, 265L, 266L, 267L, 268L, 269L, 270L, 271L, 272L,
         273L, 274L, 275L, 279L, 280L, 281L, 282L, 283L, 285L, 286L, 289L,
         294L, 295L, 296L, 297L, 299L, 300L, 302L, 304L, 305L, 306L, 307L,
         309L, 310L, 312L, 313L, 314L, 315L, 316L, 319L, 322L, 326L, 327L,
         328L, 332L, 333L, 334L, 335L, 337L, 338L, 340L, 341L, 342L, 345L,
         346L, 349L, 350L, 351L, 352L, 353L, 354L, 355L, 356L, 357L, 360L,
         362L, 363L, 364L, 365L, 366L, 367L, 369L, 371L, 372L, 373L, 374L,
         376L, 377L, 378L, 380L, 381L, 382L, 384L, 385L, 388L, 389L, 390L,
         391L, 392L, 393L, 395L, 396L, 397L, 398L, 399L, 400L, 402L, 403L,
         404L, 406L, 407L, 409L, 410L, 411L, 413L, 415L, 416L, 417L, 418L,
         419L, 420L, 421L, 422L, 423L, 426L, 427L, 428L, 429L, 430L, 431L,
         434L, 435L, 437L, 438L, 439L, 440L, 441L, 442L, 443L, 444L, 445L,
         447L, 448L, 451L, 453L, 454L, 455L, 457L, 458L, 459L, 461L, 462L,
         463L, 464L, 465L, 467L, 468L, 472L, 473L, 474L, 476L, 477L, 478L,
         479L, 480L, 481L, 482L, 485L, 486L, 487L, 488L, 489L, 490L, 491L,
         492L, 493L, 494L, 496L, 499L, 501L, 503L, 505L, 506L, 507L, 510L,
         511L, 512L, 513L, 515L, 516L, 517L, 518L, 519L, 521L, 524L, 525L,
         528L, 529L, 532L, 535L, 536L, 537L, 539L, 540L, 545L, 546L, 548L,
         549L, 550L, 551L, 552L, 554L, 555L, 556L, 557L, 559L, 561L, 562L,
         564L, 566L, 567L, 569L, 570L, 571L, 572L, 573L, 574L, 575L, 577L,
         580L, 581L, 583L, 584L, 586L, 587L, 588L, 589L, 590L, 591L, 592L,
         594L, 595L, 596L, 602L, 603L, 604L, 606L, 608L, 609L, 610L, 611L,
         612L, 614L, 615L, 616L, 617L, 618L, 619L, 620L, 624L, 625L, 626L,
         627L, 628L, 630L, 631L, 632L, 635L, 636L, 637L, 639L, 640L, 641L,
         644L, 645L, 648L, 649L, 651L, 656L, 657L, 659L, 660L, 661L, 662L,
         663L, 664L, 665L, 666L, 667L, 669L, 671L, 672L, 674L, 675L, 678L,
         679L, 681L, 683L, 684L, 685L, 687L, 689L, 690L, 691L, 694L, 695L,
         696L, 697L, 698L, 699L, 700L, 701L, 704L, 705L, 706L, 707L, 708L,
         709L, 710L, 711L, 712L, 713L, 714L, 716L, 717L, 718L, 719L, 721L,
         722L, 723L, 725L, 727L, 728L, 730L, 732L, 735L, 736L, 738L, 739L,
         740L, 742L, 743L, 745L, 746L, 747L, 748L, 750L, 751L, 752L, 754L,
         755L, 756L, 760L, 761L, 762L, 763L, 765L, 769L, 770L, 773L, 774L,
         775L, 776L, 778L, 779L, 780L, 781L, 782L, 785L, 786L, 787L, 789L,
         790L, 793L, 796L, 797L, 798L, 799L, 800L, 801L, 803L, 804L, 805L,
         807L, 809L, 810L, 812L, 814L, 816L, 817L, 819L, 821L, 823L, 824L,
         825L, 828L, 830L, 832L, 833L, 834L, 835L, 836L, 837L, 838L, 839L,
         843L, 844L, 846L, 847L, 848L, 849L, 851L, 852L, 853L, 854L, 856L,
         858L, 860L, 861L, 862L, 863L, 864L, 865L, 867L, 868L, 872L, 873L,
         875L, 876L, 877L, 878L, 879L, 880L, 881L, 882L, 883L, 884L, 885L,
         886L, 887L, 888L, 889L, 893L, 895L, 897L, 898L, 899L, 902L, 903L,
         905L, 906L, 908L, 909L, 910L, 916L, 917L, 920L, 921L, 922L, 924L,
         925L, 926L, 928L, 929L, 931L, 932L, 933L, 934L, 935L, 936L, 937L,
         939L, 940L, 944L, 947L, 949L, 950L, 951L, 952L, 953L, 954L, 955L,
         956L, 957L, 958L, 959L, 960L, 961L, 962L, 963L, 964L, 965L, 966L,
         969L, 971L, 973L, 974L, 976L, 977L, 979L, 980L, 983L, 984L, 985L,
         986L, 987L, 988L, 989L, 990L, 994L, 995L, 996L, 999L)

# Split the selected rows into predictors and the (time, status) target.
target_cols <- c("time", "status")
is_target <- names(xdf) %in% target_cols
data <- xdf[rows, !is_target]
target <- xdf[rows, is_target]

label <- target[["time"]]
status <- target[["status"]]

# survival:cox reads right-censoring from the sign of the label: every
# observation without an event (status != 1) gets a negated time.
censored <- status != 1
label[censored] <- -label[censored]

# Note: `data` is rebound from the predictor frame to the DMatrix.
data <- xgboost::xgb.DMatrix(
  data = as.matrix(data),
  label = label
)

# Fit a boosted Cox proportional-hazards model on the DMatrix built above.
# Hyperparameters go through `params` instead of `...`, so the call does not
# rely on `...` being forwarded into `params`. Newer xgboost R releases
# deprecate that forwarding (confirm against the installed version).
fit <- xgboost::xgb.train(
  params = list(
    tree_method = "hist",
    booster = "gbtree",
    objective = "survival:cox",
    eta = 0.9687533,
    max_depth = 2,
    eval_metric = "cox-nloglik"
  ),
  data = data,
  nrounds = 57
)

# For survival:cox, predict() returns hazard ratios, exp(margin), not the
# raw margin. Presumably the Inf values below come from the margin
# overflowing under this large `eta`. predict(..., outputmargin = TRUE)
# would show the log-scale values.
pred <- predict(fit, data)

# Positions (within `rows`) and values of the infinite predictions.
which(is.infinite(pred))
#>  [1] 114 121 159 213 230 231 241 246 252 260 269 295 310 323 361 615 656 658
pred[is.infinite(pred)]
#>  [1] Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf

Created on 2024-01-10 with reprex v2.0.2

Session info
#> R version 4.3.2 (2023-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Sonoma 14.2.1
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib 
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> time zone: Europe/Berlin
#> tzcode source: internal
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> loaded via a namespace (and not attached):
#>  [1] utf8_1.2.4            future_1.33.0         generics_0.1.3       
#>  [4] distr6_1.8.4          lattice_0.22-5        listenv_0.9.0        
#>  [7] digest_0.6.33         magrittr_2.0.3        evaluate_0.23        
#> [10] grid_4.3.2            ooplah_0.2.0          fastmap_1.1.1        
#> [13] jsonlite_1.8.7        xgboost_1.7.5.1       Matrix_1.6-3         
#> [16] backports_1.4.1       survival_3.5-7        param6_0.2.4         
#> [19] fansi_1.0.5           scales_1.2.1          mlr3_0.17.0          
#> [22] codetools_0.2-19      palmerpenguins_0.1.1  cli_3.6.1            
#> [25] rlang_1.1.2           crayon_1.5.2          mlr3viz_0.6.1        
#> [28] parallelly_1.36.0     splines_4.3.2         munsell_0.5.0        
#> [31] reprex_2.0.2          withr_2.5.2           yaml_2.3.7           
#> [34] mlr3pipelines_0.5.0-1 tools_4.3.2           parallel_4.3.2       
#> [37] uuid_1.1-1            set6_0.2.6            checkmate_2.3.0      
#> [40] dplyr_1.1.3           colorspace_2.1-0      ggplot2_3.4.4        
#> [43] mlr3proba_0.5.7       globals_0.16.2        vctrs_0.6.4          
#> [46] R6_2.5.1              lifecycle_1.0.4       fs_1.6.3             
#> [49] dictionar6_0.1.3      mlr3misc_0.13.0       pkgconfig_2.0.3      
#> [52] pillar_1.9.0          gtable_0.3.4          data.table_1.14.8    
#> [55] glue_1.6.2            Rcpp_1.0.11           lgr_0.4.4            
#> [58] paradox_0.11.1        xfun_0.41             tibble_3.2.1         
#> [61] tidyselect_1.2.0      rstudioapi_0.15.0     knitr_1.45           
#> [64] htmltools_0.5.7       rmarkdown_2.25        compiler_4.3.2
