Thomas Nagler 24 April, 2017
This vignette reproduces the results from the application section of the paper:
Nagler, T. (2017). A generic approach to nonparametric function estimation with mixed data. arXiv:1704.07457
The goal is to diagnose diabetic retinopathy (a disease resulting from diabetes mellitus) from images of the retina. The retinal images have been preprocessed, and a total of 19 features have been extracted: three are binary, six are integer-valued counts, and the remaining ten are continuous measurements. For more information about the preprocessing and the features, we refer to Antal and Hajdu (2014).
# Install the development versions of the required packages from GitHub
devtools::install_github("tnagler/cctools")
devtools::install_github("tnagler/kdevine")
devtools::install_github("tnagler/jdify")
library(jdify)
library(tidyverse)
set.seed(1)  # for reproducible results
First, we download the data set from the UCI Machine Learning Repository.
URL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00329/messidor_features.arff"
dat <- as_tibble(foreign::read.arff(URL))
To let jdify() know which variables to treat as discrete, we have to declare them as ordered factors. We simply treat all integer-valued variables as discrete.
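# Column 20 is the class label; every integer-valued feature becomes an
# ordered factor with one level per integer between its minimum and maximum
dat[-20] <- dat[-20] %>%
  map_if(~ all(. == round(.)), ~ ordered(., seq.int(min(.), max(.))))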
dat
## # A tibble: 1,151 × 20
## `0` `1` `2` `3` `4` `5` `6` `7` `8` `9`
## <ord> <ord> <ord> <ord> <ord> <ord> <ord> <ord> <dbl> <dbl>
## 1 1 1 22 22 22 19 18 14 49.89576 17.775994
## 2 1 1 24 24 22 18 16 13 57.70994 23.799994
## 3 1 1 62 60 59 54 47 33 55.83144 27.993933
## 4 1 1 55 53 53 50 43 31 40.46723 18.445954
## 5 1 1 44 44 44 41 39 27 18.02625 8.570709
## 6 1 1 44 43 41 41 37 29 28.35640 6.935636
## 7 1 0 29 29 29 27 25 16 15.44840 9.113819
## 8 1 1 6 6 6 6 2 1 20.67965 9.497786
## 9 1 1 22 21 18 15 13 10 66.69193 23.545543
## 10 1 1 79 75 73 71 64 47 22.14178 10.054384
## # ... with 1,141 more rows, and 10 more variables: `10` <dbl>, `11` <dbl>,
## # `12` <dbl>, `13` <dbl>, `14` <dbl>, `15` <dbl>, `16` <dbl>,
## # `17` <dbl>, `18` <ord>, Class <fctr>
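The conversion rule above can be checked on a small example. The sketch below applies the same rule with base R (`vapply()`/`lapply()` instead of `purrr::map_if()`) to a made-up data frame `toy`; the data and column names are hypothetical. It shows that the levels run from the smallest to the largest observed value, so an integer that never occurs in the data (here 2) still gets its own level.

```r
# Hypothetical toy data: one binary, one count and one continuous column
toy <- data.frame(
  binary = c(0, 1, 1, 0),
  count  = c(0, 1, 3, 3),
  cont   = c(0.25, 1.5, 2.75, 3.1)
)

# Same rule as above: a column is discrete if all values are whole numbers
is_discrete <- vapply(toy, function(x) all(x == round(x)), logical(1))

# Declare discrete columns as ordered factors with one level per integer
# between the observed minimum and maximum
toy[is_discrete] <- lapply(toy[is_discrete], function(x) {
  ordered(x, levels = seq.int(min(x), max(x)))
})

str(toy)
```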
We run 10-fold cross-validation for the two joint density classifiers based on the np and kdevine packages.
# Use all available cores
ncores <- parallel::detectCores()
# One cross-validated classifier per density estimator (np and kdevine)
cv_models <- list(liracine = "np", vine = "kdevine") %>%
  map(~ cv_jdify(Class ~ ., data = dat, .x, cores = ncores, folds = 10))
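In its simplest form, a joint density classifier estimates the joint density f(x, y) of the features and the class label and turns it into class probabilities with Bayes' rule, P(Y = y | x) = f(x, y) / sum over y' of f(x, y'). Cross-validation then fits the classifier on nine folds and predicts the held-out fold, so every observation receives an out-of-sample probability. The sketch below illustrates both steps for a single continuous feature, assuming a Gaussian kernel density estimate per class in place of the np or kdevine estimators. The functions `jd_fit_sketch()`, `jd_predict_sketch()` and `cv_probs_sketch()` are hypothetical and are not part of jdify.

```r
# Hypothetical sketch (not jdify's implementation): a joint density
# classifier for one continuous feature x and a class label y, with
# f(x, y) = P(Y = y) * f(x | Y = y) and f(x | Y = y) a Gaussian KDE.
jd_fit_sketch <- function(x, y) {
  if (!is.factor(y)) y <- factor(y)  # keep all class levels
  list(
    classes = levels(y),
    prior   = as.numeric(table(y)) / length(y),  # P(Y = y)
    samples = split(x, y),                       # training points per class
    bw      = bw.nrd0(x)                         # one rule-of-thumb bandwidth
  )
}

jd_predict_sketch <- function(fit, newx) {
  # Joint density f(newx, y): one column per class
  joint <- vapply(seq_along(fit$classes), function(k) {
    xs <- fit$samples[[k]]
    if (length(xs) == 0) return(numeric(length(newx)))
    kde <- vapply(newx, function(x0) mean(dnorm((x0 - xs) / fit$bw)) / fit$bw,
                  numeric(1))
    fit$prior[k] * kde
  }, numeric(length(newx)))
  joint <- matrix(joint, ncol = length(fit$classes),
                  dimnames = list(NULL, fit$classes))
  # Bayes' rule: normalize the joint density over the classes
  joint / rowSums(joint)
}

cv_probs_sketch <- function(x, y, folds = 10) {
  y <- factor(y)
  # Randomly assign each observation to one of `folds` folds
  fold_id <- sample(rep_len(seq_len(folds), length(y)))
  probs <- matrix(NA_real_, nrow = length(y), ncol = nlevels(y),
                  dimnames = list(NULL, levels(y)))
  for (k in seq_len(folds)) {
    test <- fold_id == k
    fit <- jd_fit_sketch(x[!test], y[!test])          # fit on the other folds
    probs[test, ] <- jd_predict_sketch(fit, x[test])  # predict held-out fold
  }
  probs
}

# Usage on simulated data: class 1 is shifted to the right
set.seed(1)
y <- rep(c(0, 1), each = 100)
x <- rnorm(200, mean = 3 * y)
probs <- cv_probs_sketch(x, y, folds = 10)
head(probs)
```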
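Next, we extract the out-of-sample class probabilities and compute the false positive rate (FPR) and true positive rate (TPR) on a grid of 1000 thresholds, together with the area under the ROC curve (AUC):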
thresh <- seq(0, 1, length.out = 1000)
results <- cv_models %>%
  map("cv_probs") %>%                       # out-of-sample class probabilities
  # compare P(Class = 1) with the true labels at every threshold
  map(~ assess_clsfyr(.x[, "1"], dat$Class == 1, c("FPR", "TPR"), thresh)) %>%
  bind_rows(.id = "method") %>%
  spread(measure, value) %>%                # one column per measure
  group_by(method) %>%
  mutate(AUC = get_auc(cbind(FPR = FPR, TPR = TPR))) %>%
  ungroup()
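The performance measures can also be computed by hand. The internals of jdify's `assess_clsfyr()` and `get_auc()` are not shown here, so the base-R sketch below only illustrates the idea under two assumptions: an observation counts as positive when its probability exceeds the threshold, and the AUC is obtained with the trapezoidal rule. The names `roc_sketch()` and `auc_sketch()` and the example probabilities are hypothetical.

```r
# FPR and TPR at each threshold; is_pos is TRUE for the positive class
roc_sketch <- function(p_hat, is_pos, thresh = seq(0, 1, length.out = 1000)) {
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  data.frame(
    threshold = thresh,
    FPR = vapply(thresh, function(t) sum(p_hat > t & !is_pos) / n_neg, numeric(1)),
    TPR = vapply(thresh, function(t) sum(p_hat > t & is_pos) / n_pos, numeric(1))
  )
}

# Area under the ROC curve via the trapezoidal rule
auc_sketch <- function(FPR, TPR) {
  o <- order(FPR, TPR)  # walk along the curve from (0, 0) to (1, 1)
  x <- FPR[o]
  y <- TPR[o]
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}

# Usage on made-up probabilities that separate the classes perfectly
p_hat  <- c(0.1, 0.2, 0.8, 0.9)
is_pos <- c(FALSE, FALSE, TRUE, TRUE)
roc <- roc_sketch(p_hat, is_pos)
auc_sketch(roc$FPR, roc$TPR)  # 1
```

Constant probabilities give an AUC of 0.5, and reversing the ranking gives 0, which makes these easy sanity checks for any AUC routine.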
Finally, we can create the ROC plot.
results %>%
  mutate(method = paste0(method, " (AUC = ", round(AUC, 2), ") "),
         method = reorder(method, -AUC)) %>%  # order the legend by AUC
  ggplot(aes(FPR, TPR, col = method, linetype = method)) +
  geom_line(size = 0.7) +
  theme(legend.position = "top", legend.key.width = unit(7.5, "mm")) +
  labs(linetype = "", col = "")
sessionInfo()
## R version 3.4.0 (2017-04-21)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Linux Mint 18.1
##
## Matrix products: default
## BLAS: /usr/lib/openblas-base/libblas.so.3
## LAPACK: /usr/lib/libopenblasp-r0.2.18.so
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=de_DE.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=de_DE.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=de_DE.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] dplyr_0.5.0 purrr_0.2.2 readr_1.1.0 tidyr_0.6.1
## [5] tibble_1.3.0 ggplot2_2.2.1 tidyverse_1.1.1 jdify_0.1.0.9000
##
## loaded via a namespace (and not attached):
## [1] reshape2_1.4.2 haven_1.0.0 lattice_0.20-35
## [4] colorspace_1.3-2 htmltools_0.3.5 yaml_2.1.14
## [7] foreign_0.8-67 withr_1.0.2 DBI_0.6-1
## [10] modelr_0.1.0 readxl_1.0.0 foreach_1.4.3
## [13] plyr_1.8.4 stringr_1.2.0 munsell_0.4.3
## [16] gtable_0.2.0 cellranger_1.1.0 rvest_0.3.2
## [19] devtools_1.12.0 codetools_0.2-15 psych_1.7.3.21
## [22] memoise_1.1.0 evaluate_0.10 labeling_0.3
## [25] knitr_1.15.1 forcats_0.2.0 doParallel_1.0.10
## [28] parallel_3.4.0 curl_2.5 broom_0.4.2
## [31] Rcpp_0.12.10 backports_1.0.5 scales_0.4.1
## [34] jsonlite_1.4 cctools_0.1.0 mnormt_1.5-5
## [37] hms_0.3 digest_0.6.12 stringi_1.1.5
## [40] grid_3.4.0 rprojroot_1.2 tools_3.4.0
## [43] magrittr_1.5 lazyeval_0.2.0 qrng_0.0-3
## [46] xml2_1.1.1 lubridate_1.6.0 assertthat_0.2.0
## [49] rmarkdown_1.4 httr_1.2.1 iterators_1.0.8
## [52] R6_2.2.0 nlme_3.1-131 git2r_0.18.0
## [55] compiler_3.4.0