richpauloo/README.md Last active Jul 28, 2019

Cumulative Variable Importance for Random Forest (RF) 🌲🌳 Models

Motivation

What does an interpretable RF visualization look like? Out-of-the-box 📦 RF implementations in R and Python compute variable importance over all trees, but how do we get there?

In other words, what would cumulative variable importance for an RF look like as trees are added one at a time?

Approach

The randomForest R package (to the best of my knowledge) doesn't record the individual variable importance for each CART in the forest. Instead, it supplies the overall (summarized) importance via importance(rf_model).
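To make the limitation concrete, here is a small sketch (the seed and the number of trees are arbitrary choices for illustration) showing that importance() returns one summarized row per predictor, with no per-tree breakdown:

```r
library(randomForest)

set.seed(1)  # arbitrary seed, for reproducibility only
rf_model <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = 100)

# One row per predictor; values are summarized over all 100 trees
imp <- importance(rf_model)
print(dim(imp))       # predictors x importance measures
print(colnames(imp))  # includes "%IncMSE" because importance = TRUE
```

Nothing in this matrix says how much any single tree contributed, which is what motivates the workaround described next.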

Thus, instead of fitting one RF of n trees, I fit n RFs of 1 tree each and compute the cumulative %IncMSE: the mean %IncMSE over the first k single-tree models, for k = 1, ..., n. Then, I plot the forest tree-by-tree alongside the cumulative variable importance as each tree is added.
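The idea can be sketched compactly in base R plus randomForest. This is an illustrative reduction, not the full script: the names n_trees, per_tree and cum_imp are made up for this example, and only 5 trees are grown to keep it fast.

```r
library(randomForest)

n_trees <- 5  # small value chosen for illustration only

# Fit one single-tree forest per seed and keep its %IncMSE column.
# Result: one row per predictor, one column per tree.
per_tree <- sapply(seq_len(n_trees), function(seed) {
  set.seed(seed)
  fit <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = 1)
  importance(fit)[, "%IncMSE"]
})

# Running mean across trees: column k holds the mean importance
# of each predictor over trees 1..k.
cum_imp <- sweep(t(apply(per_tree, 1, cumsum)), 2, seq_len(n_trees), "/")
```

The last column of cum_imp equals the plain mean over all trees, which is the quantity the full script uses to rank variables.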

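The full script at the end of this README is a minimal example with mtcars. Start by cloning https://github.com/richpauloo/reprtree and setting the file path in the list.files() call (line 19 of the script) to the directory containing the repo's R files.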

This script takes about 1 minute to run on my personal computer. Unless your display is very large, you may need to expand the final animation.
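A wrong file path is the most likely way for the sourcing step to fail silently, because list.files() returns an empty vector for a missing directory and lapply() then sources nothing. The sketch below is one way to guard against that. Note that source_dir() is a hypothetical helper written for this example; it is not part of reprtree or of the script.

```r
# Hypothetical helper (not part of reprtree): source every .R file in a
# directory and stop with a clear message if the path is wrong or an
# expected function is still missing afterwards.
source_dir <- function(path, required = character()) {
  if (!dir.exists(path)) stop("Directory not found: ", path)
  files <- list.files(path, pattern = "\\.[Rr]$", full.names = TRUE)
  if (length(files) == 0) stop("No R files found in: ", path)
  invisible(lapply(files, source))  # source() evaluates in the global env

  # Check that each required function now exists in the global environment
  found <- vapply(
    required,
    function(fn) exists(fn, envir = globalenv(), mode = "function"),
    logical(1)
  )
  if (!all(found)) {
    stop("Not defined after sourcing: ", paste(required[!found], collapse = ", "))
  }
  invisible(files)
}
```

With the clone in place, a call such as source_dir("/Users/richpauloo/GitHub/reprtree/R", required = "plot.getTree") (substituting your own path) would either load the plotting function or say what went wrong.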

Next Steps

If you think this is cool and useful enough to be an R package 📦, if you'd use it to make your RF models more interpretable, or want to work on it with me, please let me know on Twitter (@RichPauloo), or at my email: richpauloo at gmail dot com.

I'm looking into the randomForest source code to see if I can add an option to save the importance data of each CART model, so that I can build functions that take randomForest objects directly, rather than needing to re-run the models tree-by-tree.

Thanks for your interest!

# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)

# source reprtree package for plotting individual trees generated by RF
# clone repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing the R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)

# example data == mtcars
df <- mtcars

# number of trees to grow in random forest
nn <- 100

# function to fit one CART model (a random forest with a single tree)
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ ., data = df, importance = TRUE, ntree = 1)
  return(one_tr)
}

# list to store output of each model
l <- lapply(1:nn, run_rf)

# number of predictors in RF mod
# (every model uses the same predictors, so the first one is enough)
npred <- length(names(l[[1]]$forest$xlevels))

# extract importance of each CART model,
# moving the variable names from the row names into a `var` column
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number

# summarised var imp
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)

# ranked variables
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)

# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))

# initialize lists for: varimp, trees, plot titles, and combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))

for(i in seq_along(plt_vec)){
  # cumulative variable importance with each tree's addition
  pl[[i]] <- filter(impdf, tree_num %in% 1:plt_vec[i]) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`)) %>%
    ggplot(aes(forcats::fct_rev(var), mse, fill = var)) +
    geom_col() +
    coord_flip(ylim = c(0, max(tot_mse$`%IncMSE`))) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable") +
    theme_minimal() +
    theme(legend.position = "bottom")

  # make tree plots
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()

  # make plot titles
  pt[[i]] <- ggdraw() + draw_label(paste0("Tree ", plt_vec[i]), fontface = 'bold')

  # combine all plots with title
  bp[[i]] <- plot_grid(pt[[i]], plot_grid(tl[[i]], pl[[i]]), ncol = 1, rel_heights = c(0.1, 1))
}

# use magick to turn plots into a GIF.
# WARNING: magick doesn't handle hundreds of plots well in my experience
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of temporary files that magick creates...
# Also, this animation may be too large to fit in your viewer, so
# be sure to expand it!
img <- image_graph(1000, 600, res = 96)
for(i in seq_along(plt_vec)){
  print( bp[[i]] )
}
dev.off()
animation <- image_animate(img, fps = 2)
print(animation)

# save to working directory
image_write(animation, "anim.gif")

# uncomment and run to print to PDF and make the gif elsewhere,
# like https://ezgif.com/maker
# pdf("all.pdf", width = 12, height = 7)
# invisible(lapply(bp, print))
# dev.off()

JGDS01 commented May 5, 2019 • edited

Hi, great code, thanks for sharing. One question: I'm having issues with plot.getTree. Is this part of the reprtree package? It could be my mistake, but I can't seem to run the function. Thanks

richpauloo (Owner, Author) commented May 6, 2019 • edited

 Hi @JGDS01, plot.getTree is indeed part of the reprtree package, but I modified that function to take arguments for color scales, so you'll need to clone https://github.com/richpauloo/reprtree and source those files into R in order to get the plot.getTree that achieves this (see lines 14-22 above). I suspect that you might not be loading this function into R. Try list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE) and make sure your file path is correct in line 19 and that you're actually sourcing these files into R.