
@richpauloo
Last active December 24, 2019 01:43
Cumulative Variable Importance for Random Forest Models

Cumulative Variable Importance for Random Forest (RF) 🌲🌳 Models

Motivation

What does an interpretable RF visualization look like? Out-of-the-box 📦 RF implementations in R and Python compute variable importance over all trees, but how does that overall importance build up as trees are added?

In other words, what would a cumulative variable importance for a RF look like?

Approach

The randomForest R package (to the best of my knowledge) doesn't record the individual variable importance for each CART in the forest. Instead, it supplies the overall (summarized) importance via importance(rf_model).
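For reference, the summarized importance that randomForest does expose can be pulled out as follows. This is a minimal sketch on mtcars; `ntree = 100` and the seed are arbitrary choices, not taken from the script. `type = 1` selects the permutation measure (`%IncMSE` for regression), which `importance()` scales by its standard deviation by default.

```r
library(randomForest)

set.seed(1)
# importance = TRUE makes randomForest compute permutation importance (%IncMSE)
rf_model <- randomForest(mpg ~ ., data = mtcars, ntree = 100, importance = TRUE)

# type = 1: permutation importance only; one row per predictor
imp <- importance(rf_model, type = 1)

# rank predictors from most to least important
imp[order(-imp[, "%IncMSE"]), , drop = FALSE]
```

Only the forest-wide summary is returned, one number per predictor, which is why the script below refits one-tree forests to get tree-by-tree values.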

Thus, instead of fitting an RF of n trees, I fit n RFs of 1 tree each and compute the cumulative mean %IncMSE. Then, I plot the forest tree-by-tree alongside the cumulative variable importance as the nth tree is added.
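The cumulative importance at tree k is just the mean of each variable's per-tree %IncMSE over trees 1..k. The script recomputes that mean separately for every plotted k; the same numbers can also be obtained in one pass as a running mean. The sketch below uses base R and a tiny hypothetical table (made-up values, with the column called `IncMSE` instead of `%IncMSE` for readability) shaped like the script's `impdf`.

```r
# Toy per-tree importance table (hypothetical values): one row per
# variable per tree, like the impdf data frame built in the script
imp_long <- data.frame(
  var      = rep(c("wt", "hp"), times = 3),
  tree_num = rep(1:3, each = 2),
  IncMSE   = c(4, 1, 2, 3, 6, 2)
)

# Sort by variable, then tree, so the running mean accumulates in tree order
imp_long <- imp_long[order(imp_long$var, imp_long$tree_num), ]

# Running mean within each variable:
# value at tree k = mean of that variable's importance over trees 1..k
imp_long$cum_IncMSE <- ave(
  imp_long$IncMSE, imp_long$var,
  FUN = function(v) cumsum(v) / seq_along(v)
)

imp_long
```

Here `wt` has per-tree values 4, 2, 6, so its cumulative importance is 4, 3, 4; for `hp` (1, 3, 2) it is 1, 2, 2.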

The code below is a minimal example with mtcars. Start by cloning https://github.com/richpauloo/reprtree and setting the path to its R/ directory in the list.files() call near the top of the script.

This script takes ~ 1 min to run on my personal computer. Unless your display is very large, you may need to expand the final animation.

Next Steps

If you think this is cool and useful enough to be an R package 📦, if you'd use it to make your RF models more interpretable, or want to work on it with me, please let me know on Twitter (@RichPauloo), or at my email: richpauloo at gmail dot com.

I'm looking into the randomForest source code to see if I can add an option to save the importance data of each CART model, so that I can build functions that take randomForest objects directly, rather than needing to re-run the models tree-by-tree.
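One possible route that avoids refitting relies on two documented randomForest features: `predict(..., predict.all = TRUE)` returns every tree's predictions, and `keep.inbag = TRUE` records which rows each tree was trained on (rows with an in-bag count of 0 are that tree's out-of-bag cases). The sketch below combines them into a per-tree, unscaled OOB permutation importance. `per_tree_incmse()` is a hypothetical helper, not part of randomForest, and the result is only an approximation of what `importance()` summarizes: randomForest permutes inside its compiled code with its own random draws, and %IncMSE further divides the averaged increase by its standard deviation, so the numbers will not match `importance(rf)` exactly.

```r
library(randomForest)

set.seed(42)
# keep.inbag = TRUE stores, for every tree, how often each row was sampled
rf <- randomForest(mpg ~ ., data = mtcars, ntree = 50, keep.inbag = TRUE)

preds <- setdiff(names(mtcars), "mpg")
y <- mtcars$mpg

# Hypothetical helper: per-tree increase in OOB MSE (unscaled) after
# permuting one predictor among that tree's OOB rows
per_tree_incmse <- function(rf, data, y, preds) {
  ntree <- rf$ntree
  out <- matrix(NA_real_, nrow = ntree, ncol = length(preds),
                dimnames = list(NULL, preds))
  # n x ntree matrix of individual tree predictions on the unpermuted data
  base_pred <- predict(rf, data, predict.all = TRUE)$individual
  for (k in seq_len(ntree)) {
    oob <- which(rf$inbag[, k] == 0)
    if (length(oob) < 2) next  # too few OOB rows to permute; leave NA
    base_mse <- mean((y[oob] - base_pred[oob, k])^2)
    for (p in preds) {
      perm <- data[oob, , drop = FALSE]
      perm[[p]] <- sample(perm[[p]])  # shuffle predictor p among OOB rows
      perm_pred <- predict(rf, perm, predict.all = TRUE)$individual[, k]
      out[k, p] <- mean((y[oob] - perm_pred)^2) - base_mse
    }
  }
  out
}

imp_by_tree <- per_tree_incmse(rf, mtcars, y, preds)

# Cumulative (running) mean per variable as trees are added, skipping NAs
cum_imp <- apply(imp_by_tree, 2, function(v) {
  cumsum(ifelse(is.na(v), 0, v)) / cumsum(!is.na(v))
})
```

Each row of `cum_imp` is the cumulative importance after that many trees, computed from a single fitted forest.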

Thanks for your interest!

# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)

# source the reprtree package for plotting individual trees generated by RF:
# clone the repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing its R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)
# example data == mtcars
df <- mtcars

# number of trees to grow in the random forest
nn <- 500

# fit a random forest containing a single tree (one CART model) for a
# given seed; calling this nn times gives nn one-tree forests
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ .,
                         data = df,
                         importance = TRUE,
                         ntree = 1)
  return(one_tr)
}
# list to store the output of each model
l <- lapply(1:nn, run_rf)

# number of predictors in the RF model
npred <- length(names(l[[1]]$forest$xlevels))

# extract the importance of each CART model into one long data frame
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number

# summarised variable importance over all trees
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)

# variables ranked by overall importance
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)
# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))

# initialize lists for: varimp, trees, plot titles, and combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))

for(i in seq_along(plt_vec)){
  # cumulative variable importance with each tree's addition
  pl[[i]] <- filter(impdf, tree_num %in% 1:plt_vec[i]) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`)) %>%
    ggplot(aes(forcats::fct_rev(var), mse, fill = var)) +
    geom_col() +
    coord_flip(ylim = c(0, max(tot_mse$`%IncMSE`))) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable",
         title = paste0("Tree ", plt_vec[i])) +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(size = 25))

  # make tree plots
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()

  # make plot titles
  pt[[i]] <- ggdraw() +
    draw_label(
      paste0("Tree ", plt_vec[i]),
      fontface = 'bold',
      x = 0,
      hjust = 0
    ) +
    theme(
      # add margin on the left of the drawing canvas,
      # so title is aligned with left edge of first plot
      plot.margin = margin(0, 0, 0, 7)
    )

  # combine the importance plot and the tree plot side by side
  # (the importance plot already carries the "Tree N" title)
  bp[[i]] <- plot_grid(pl[[i]], tl[[i]])
}
# use magick to turn the plots into a GIF.
# WARNING: in my experience magick doesn't handle hundreds of plots well,
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of the temporary files that magick creates.
# Also, this animation may be too large to fit in your viewer, so
# be sure to expand it!
# img <- image_graph(1000, 600, res = 96)
# for(i in seq_along(plt_vec)){ print( bp[[i]] ) }
# dev.off()
# animation <- image_animate(img, fps = 2)
# print(animation)
#
# # save to working directory
# image_write(animation, "anim.gif")

# alternatively (this is what runs by default): print every frame to one
# PDF and make the GIF elsewhere, like https://ezgif.com/maker
pdf("all.pdf", width = 12, height = 7)
invisible(lapply(bp, print))
dev.off()

richpauloo commented May 6, 2019

Hi @JGDS01, plot.getTree is indeed part of the reprtree package, but I modified that function to take arguments for color scales, so you'll need to clone https://github.com/richpauloo/reprtree and source those files into R to get the version of plot.getTree that does this (see the list.files()/source() block near the top of the script). I suspect that you might not be loading this function into R. Try list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE), make sure the path you pass to list.files() is correct, and check that you're actually sourcing these files into R.
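A small check along these lines can make a wrong path or a missing function fail loudly instead of surfacing later as "could not find function". `source_r_dir()` is a hypothetical helper (not part of reprtree), and the path in the usage line is a placeholder for your own clone.

```r
# Hypothetical helper: source every .R file in a directory, and stop with a
# clear message if the directory is wrong or an expected function is missing
source_r_dir <- function(dir, expect = NULL, envir = globalenv()) {
  if (!dir.exists(dir)) stop("Directory not found: ", dir)
  files <- list.files(dir, pattern = "\\.[Rr]$", full.names = TRUE)
  if (length(files) == 0) stop("No .R files found in: ", dir)
  for (f in files) sys.source(f, envir = envir)
  # names in `expect` that still don't exist after sourcing
  missing <- expect[!vapply(expect, exists, logical(1), envir = envir)]
  if (length(missing) > 0) {
    stop("Not defined after sourcing: ", paste(missing, collapse = ", "))
  }
  invisible(files)
}

# Usage (placeholder path; point it at the R/ directory of your clone):
# source_r_dir("path/to/reprtree/R", expect = "plot.getTree")
```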
