Cumulative Variable Importance for Random Forest Models

Cumulative Variable Importance for Random Forest (RF) 🌲🌳 Models

Motivation

What does an interpretable RF visualization look like? Out-of-the-box 📦 RF implementations in R and Python compute variable importance over all trees, but how do we get there?

In other words, what would cumulative variable importance for an RF look like?

Approach

The randomForest R package (to the best of my knowledge) doesn't record the individual variable importance for each CART in the forest. Instead, it supplies the overall (summarized) importance via importance(rf_model).
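For reference, here is a minimal sketch of that summarized output, assuming the randomForest package is installed. The object name rf_model, the seed, and ntree = 100 are illustrative choices rather than values from the script further down.

```r
library(randomForest)

set.seed(1)  # arbitrary seed, only for reproducibility of this sketch

# importance = TRUE asks randomForest to compute the permutation-based
# importance (%IncMSE for regression) in addition to IncNodePurity
rf_model <- randomForest(mpg ~ ., data = mtcars,
                         importance = TRUE, ntree = 100)

# one row per predictor, summarized over all 100 trees;
# there is no per-tree breakdown in this matrix
imp <- importance(rf_model)
print(dim(imp))
print(colnames(imp))
```

The matrix has one row for each of the 10 predictors in mtcars, which is why the approach below falls back on fitting single-tree forests to get a per-tree view.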

Thus, instead of fitting an RF of n trees, I fit n RFs of 1 tree each and compute the cumulative %IncMSE: for each variable, the running mean of its %IncMSE over the first k trees. Then, I plot the forest tree-by-tree alongside the cumulative variable importance as each tree is added.
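To make "cumulative" concrete, here is a small base R sketch of the running mean. The per-tree importance values are made up for illustration (they are not real model output), and the names per_tree and cum_imp are hypothetical. The full script below does the same calculation with dplyr by filtering to the first k trees and averaging.

```r
# hypothetical per-tree importances: rows = trees 1..3, columns = predictors
per_tree <- matrix(c(4, 2,
                     6, 0,
                     2, 4),
                   ncol = 2, byrow = TRUE,
                   dimnames = list(NULL, c("wt", "hp")))

# running mean down each column:
# row k holds the mean importance over trees 1..k
cum_imp <- apply(per_tree, 2, function(x) cumsum(x) / seq_along(x))

print(cum_imp)
```

The last row of cum_imp equals colMeans(per_tree), so the final animation frame shows the mean importance over all fitted trees.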

The code below is a minimal example with mtcars. Start by cloning https://github.com/richpauloo/reprtree and setting the path to its R directory in the list.files() call near the top of the script.

This script takes ~ 1 min to run on my personal computer. Unless your display is very large, you may need to expand the final animation.

Next Steps

If you think this is cool and useful enough to be an R package 📦, if you'd use it to make your RF models more interpretable, or want to work on it with me, please let me know on Twitter (@RichPauloo), or at my email: richpauloo at gmail dot com.

I'm looking into the randomForest source code to see if I can add an option to save the importance data of each CART model, so that I can build functions that take randomForest objects directly, rather than needing to re-run the models tree-by-tree.

Thanks for your interest!

# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)
# source the reprtree package for plotting individual trees generated by RF
# clone repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing the R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)
# example data == mtcars
df <- mtcars
# number of trees to grow in random forest
nn <- 100
# function to fit one single-tree RF (one CART) for a given random seed;
# it is called nn times below, once per seed
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ .,
                         data = df,
                         importance = TRUE,
                         ntree = 1)
  return(one_tr)
}
# list to store output of each model
l <- lapply(1:nn, run_rf)
# number of predictors in RF mod
npred <- length(names(l[[1]]$forest$xlevels))
# extract importance of each CART model
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number
# summarised var imp (mean %IncMSE per variable over all nn trees)
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)
# ranked variables
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)
# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))
# initialize lists for: varimp, trees, plot titles, and combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))
for(i in seq_along(plt_vec)){
  # cumulative variable importance with each tree's addition
  pl[[i]] <- filter(impdf, tree_num %in% 1:plt_vec[i]) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`)) %>%
    ggplot(aes(forcats::fct_rev(var), mse, fill=var)) +
    geom_col() +
    coord_flip(ylim = c(0, max(tot_mse$`%IncMSE`))) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable")+
    theme_minimal() +
    theme(legend.position = "bottom")
  # make tree plots
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()
  # make plot titles
  pt[[i]] <- ggdraw() + draw_label(paste0("Tree ", plt_vec[i]),
                                   fontface='bold')
  # combine all plots with title
  bp[[i]] <- plot_grid(pt[[i]], plot_grid(tl[[i]], pl[[i]]),
                       ncol=1, rel_heights=c(0.1, 1))
}
# use magick to turn plots into a GIF.
# WARNING: magick doesn't handle hundreds of plots well in my experience
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of temporary files that magick creates...
# Also, this animation may be too large to fit in your viewer, so
# be sure to expand it!
img <- image_graph(1000, 600, res = 96)
for(i in seq_along(plt_vec)){ print( bp[[i]] ) }
dev.off()
animation <- image_animate(img, fps = 2)
print(animation)
# save to working directory
image_write(animation, "anim.gif")
# uncomment and run to print to PDF and make the GIF elsewhere,
# like https://ezgif.com/maker
# pdf("all.pdf", width = 12, height = 7)
# invisible(lapply(bp, print))
# dev.off()
@JGDS01

commented May 5, 2019

Hi - great code, thanks for sharing.

One question: I'm having issues with plot.getTree. Is this part of the reprtree package? It could be my mistake, but I can't seem to run the function.

Thanks

@richpauloo

Owner Author

commented May 6, 2019

Hi @JGDS01, plot.getTree is indeed part of the reprtree package, but I modified that function to take arguments for color scales, so you'll need to clone https://github.com/richpauloo/reprtree and source those files into R in order to get the plot.getTree that achieves this (see lines 14-22 above). I suspect that you might not be loading this function into R. Try list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE) and make sure your file path is correct in line 19 and that you're actually sourcing these files into R.
