jdify: Retinopathy classification

Thomas Nagler 24 April, 2017

This vignette reproduces the results of the application section of the paper:

Nagler, T. (2017). A generic approach to nonparametric function estimation with mixed data. arXiv:1704.07457

The goal is to diagnose diabetic retinopathy (a disease resulting from diabetes mellitus) from images of the retina. The retinal images have been preprocessed, and a total of 19 features have been extracted. Three features are binary, six are integer-valued count variables, and the remaining 10 features are continuous measurements. For more information about the preprocessing and the features, we refer to Antal and Hajdu (2014).

Required libraries

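# development versions from GitHub (only needed once)
devtools::install_github("tnagler/cctools")
devtools::install_github("tnagler/kdevine")
devtools::install_github("tnagler/jdify")
library(jdify)
library(tidyverse)  # attaches dplyr, purrr, tidyr, tibble, ggplot2, readr
set.seed(1)         # for reproducibility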

Data download and preparation

First, we download the data set from the UCI Machine Learning Repository.

URL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00329/messidor_features.arff"
# read the ARFF file directly from the URL and convert it to a tibble
dat <- as_tibble(foreign::read.arff(URL))

To let jdify() know which variables to treat as discrete, we have to declare them as ordered factors. We simply treat all integer-valued variables as discrete.

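# convert every integer-valued feature (all columns except Class, column 20)
# to an ordered factor with levels min(x), ..., max(x)
dat[-20] <- dat[-20] %>% 
    map_if(~ all(. == round(.)), ~ ordered(., seq.int(min(.), max(.))))
dat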
## # A tibble: 1,151 × 20
##      `0`   `1`   `2`   `3`   `4`   `5`   `6`   `7`      `8`       `9`
##    <ord> <ord> <ord> <ord> <ord> <ord> <ord> <ord>    <dbl>     <dbl>
## 1      1     1    22    22    22    19    18    14 49.89576 17.775994
## 2      1     1    24    24    22    18    16    13 57.70994 23.799994
## 3      1     1    62    60    59    54    47    33 55.83144 27.993933
## 4      1     1    55    53    53    50    43    31 40.46723 18.445954
## 5      1     1    44    44    44    41    39    27 18.02625  8.570709
## 6      1     1    44    43    41    41    37    29 28.35640  6.935636
## 7      1     0    29    29    29    27    25    16 15.44840  9.113819
## 8      1     1     6     6     6     6     2     1 20.67965  9.497786
## 9      1     1    22    21    18    15    13    10 66.69193 23.545543
## 10     1     1    79    75    73    71    64    47 22.14178 10.054384
## # ... with 1,141 more rows, and 10 more variables: `10` <dbl>, `11` <dbl>,
## #   `12` <dbl>, `13` <dbl>, `14` <dbl>, `15` <dbl>, `16` <dbl>,
## #   `17` <dbl>, `18` <ord>, Class <fctr>
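The conversion to ordered factors can be illustrated on a small toy data frame. This is a minimal base-R sketch: the column names and values are made up, and `vapply()`/`lapply()` stand in for `purrr::map_if()`. Because the levels span the whole range `min(x):max(x)`, integer values that never occur in the data still get a level.

```r
# Toy data frame standing in for the feature columns (hypothetical values)
toy <- data.frame(
  quality  = c(1, 1, 0, 1),             # binary
  ma_count = c(22, 24, 6, 79),          # integer-valued count
  exudate  = c(49.9, 57.7, 20.7, 22.1)  # continuous
)

# TRUE if every value of a numeric column is a whole number
is_integer_valued <- function(x) all(x == round(x))

# Ordered factor whose levels cover the full range min(x), ..., max(x)
to_ordered <- function(x) ordered(x, levels = seq.int(min(x), max(x)))

int_cols <- vapply(toy, is_integer_valued, logical(1))
toy[int_cols] <- lapply(toy[int_cols], to_ordered)

str(toy)
```

Here `quality` and `ma_count` become ordered factors, while `exudate` stays numeric.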

Model fitting and performance assessment

We run 10-fold cross-validation for the two joint density classifiers, one based on the np package (labelled liracine) and one based on the kdevine package (labelled vine).

ncores <- parallel::detectCores()  # use all available cores
# cross-validate both methods; the list names become the method labels
cv_models <- list(liracine = "np", vine = "kdevine") %>% 
    map(~ cv_jdify(Class ~ ., data = dat, .x, cores = ncores, folds = 10))
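A joint density classifier turns an estimate of the joint density f(x, y) into class probabilities via P(Y = k | x) = f(x, k) / sum_j f(x, j). The base-R sketch below illustrates this idea together with 10-fold cross-validation on simulated data. It is not how jdify works internally: it assumes continuous features only, uses a product Gaussian kernel with `bw.nrd0()` bandwidths, and writes f(x, k) as the class frequency times a class-wise density estimate. The data and the helper names `kde_product()` and `jd_classify()` are hypothetical.

```r
set.seed(1)

# Simulated two-class data with two continuous features (hypothetical)
n <- 200
y <- factor(rbinom(n, 1, 0.5), levels = c(0, 1))
x <- cbind(rnorm(n, mean = as.numeric(y == "1")), rnorm(n))

# Product Gaussian kernel density estimate at each row of `newx`
kde_product <- function(train, newx) {
  bw <- apply(train, 2, bw.nrd0)  # normal-reference bandwidth per feature
  vapply(seq_len(nrow(newx)), function(i) {
    z <- sweep(sweep(train, 2, newx[i, ]), 2, bw, "/")  # scaled distances
    k <- matrix(dnorm(z), nrow = nrow(train))
    mean(apply(k, 1, prod)) / prod(bw)  # product over features, mean over points
  }, numeric(1))
}

# P(Y = k | x) = f(x, k) / sum_j f(x, j), with f(x, k) = P(Y = k) * f(x | Y = k)
jd_classify <- function(train_x, train_y, newx) {
  classes <- levels(train_y)
  joint <- vapply(classes, function(k) {
    idx <- train_y == k
    mean(idx) * kde_product(train_x[idx, , drop = FALSE], newx)
  }, numeric(nrow(newx)))
  joint <- matrix(joint, nrow = nrow(newx), dimnames = list(NULL, classes))
  joint / rowSums(joint)
}

# 10-fold cross-validation: each observation is predicted by a model
# that was fitted without it (out-of-sample probabilities)
folds <- sample(rep_len(1:10, n))
cv_probs <- matrix(NA_real_, n, 2, dimnames = list(NULL, levels(y)))
for (f in 1:10) {
  test <- folds == f
  cv_probs[test, ] <- jd_classify(x[!test, , drop = FALSE], y[!test],
                                  x[test, , drop = FALSE])
}
head(cv_probs)
```

Each row of `cv_probs` sums to one, and its column "1" plays the same role as the `cv_probs[, "1"]` column used in the next step.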

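Next, we extract the out-of-sample class probabilities and compute the false positive rate (FPR) and true positive rate (TPR) on a grid of 1000 thresholds, as well as the area under the ROC curve (AUC):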

thresh <- seq(0, 1, length.out = 1000)  # classification thresholds
results <- cv_models %>%
    map("cv_probs") %>%                  # out-of-sample class probabilities
    map(~ assess_clsfyr(.x[, "1"], dat$Class == 1, c("FPR", "TPR"), thresh)) %>%
    bind_rows(.id = "method") %>%
    spread(measure, value) %>%           # one column each for FPR and TPR
    group_by(method) %>%
    mutate(AUC = get_auc(cbind(FPR = FPR, TPR = TPR))) %>%
    ungroup()
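The two rates and the AUC can be sketched in base R on a handful of made-up scores. This is an illustration under assumptions, not the code behind `assess_clsfyr()` or `get_auc()`: it assumes an observation counts as positive when its score exceeds the threshold, and it integrates the ROC curve with the trapezoidal rule. The helper names `roc_points()` and `trapezoid_auc()` are hypothetical.

```r
# Hypothetical scores and labels standing in for cv_probs[, "1"] and Class == 1
score <- c(0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2)
truth <- c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE, FALSE)

# Predict "positive" when score > t; compute the rates at each threshold
roc_points <- function(score, truth, thresh) {
  tpr <- sapply(thresh, function(t) mean(score[truth] > t))
  fpr <- sapply(thresh, function(t) mean(score[!truth] > t))
  data.frame(thresh = thresh, FPR = fpr, TPR = tpr)
}

# Area under the ROC curve by the trapezoidal rule, after sorting by FPR
trapezoid_auc <- function(fpr, tpr) {
  o <- order(fpr, tpr)
  fpr <- fpr[o]
  tpr <- tpr[o]
  sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
}

roc <- roc_points(score, truth, seq(0, 1, length.out = 1000))
trapezoid_auc(roc$FPR, roc$TPR)  # 0.75
```

For these toy values, 12 of the 16 positive/negative pairs are ranked correctly, which matches the AUC of 0.75: the AUC is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative one.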

ROC plot

Finally, we can create the ROC plot.

results %>% 
    mutate(method = paste0(method, " (AUC = ", round(AUC, 2), ")  "),
           method = reorder(method, -AUC)) %>%  # highest AUC first in legend
    ggplot(aes(FPR, TPR, col = method, linetype = method)) +
    geom_line(size = 0.7) +
    theme(legend.position = "top", legend.key.width = unit(7.5, "mm")) +
    labs(linetype = "", col = "")

Environment

sessionInfo()
## R version 3.4.0 (2017-04-21)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Linux Mint 18.1
## 
## Matrix products: default
## BLAS: /usr/lib/openblas-base/libblas.so.3
## LAPACK: /usr/lib/libopenblasp-r0.2.18.so
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
##  [5] LC_MONETARY=de_DE.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=de_DE.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=de_DE.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] dplyr_0.5.0      purrr_0.2.2      readr_1.1.0      tidyr_0.6.1     
## [5] tibble_1.3.0     ggplot2_2.2.1    tidyverse_1.1.1  jdify_0.1.0.9000
## 
## loaded via a namespace (and not attached):
##  [1] reshape2_1.4.2    haven_1.0.0       lattice_0.20-35  
##  [4] colorspace_1.3-2  htmltools_0.3.5   yaml_2.1.14      
##  [7] foreign_0.8-67    withr_1.0.2       DBI_0.6-1        
## [10] modelr_0.1.0      readxl_1.0.0      foreach_1.4.3    
## [13] plyr_1.8.4        stringr_1.2.0     munsell_0.4.3    
## [16] gtable_0.2.0      cellranger_1.1.0  rvest_0.3.2      
## [19] devtools_1.12.0   codetools_0.2-15  psych_1.7.3.21   
## [22] memoise_1.1.0     evaluate_0.10     labeling_0.3     
## [25] knitr_1.15.1      forcats_0.2.0     doParallel_1.0.10
## [28] parallel_3.4.0    curl_2.5          broom_0.4.2      
## [31] Rcpp_0.12.10      backports_1.0.5   scales_0.4.1     
## [34] jsonlite_1.4      cctools_0.1.0     mnormt_1.5-5     
## [37] hms_0.3           digest_0.6.12     stringi_1.1.5    
## [40] grid_3.4.0        rprojroot_1.2     tools_3.4.0      
## [43] magrittr_1.5      lazyeval_0.2.0    qrng_0.0-3       
## [46] xml2_1.1.1        lubridate_1.6.0   assertthat_0.2.0 
## [49] rmarkdown_1.4     httr_1.2.1        iterators_1.0.8  
## [52] R6_2.2.0          nlme_3.1-131      git2r_0.18.0     
## [55] compiler_3.4.0