
@tjvananne
Created July 25, 2017 19:57
Xgboost Feature Importance
# xgboost feature importance ----------------------------------------------------------------
#' Author: Taylor Van Anne
#' 7/25/2017
#'
#' This script is a simple demonstration of:
#' 1) using cross validation to determine the optimal number of iterations for your xgboost model
#' 2) running xgboost with that number of iterations
#' 3) extracting the feature importance from the resulting model
library(xgboost)
set.seed(4)
# set up data
my_iris <- iris
just_features <- my_iris[, setdiff(names(my_iris), "Species")]
just_labels <- as.integer(as.factor(my_iris$Species)) - 1
# build a DMatrix
my_iris_dmat <- xgboost::xgb.DMatrix(as.matrix(just_features), label=just_labels)
# general params for both CV and real modeling -- change your objective function to match your situation
myxgb_params <- list(objective = "multi:softmax", num_class=3, nthread=1)
# list of additional parameters:
# http://xgboost.readthedocs.io/en/latest/parameter.html
# bonus material - difference between 'objective' and 'feval':
# https://stackoverflow.com/questions/34178287/difference-between-objective-and-feval-in-xgboost
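# Illustrative sketch (not part of the original gist): 'objective' is the loss
# that training optimizes, while 'feval' only supplies a custom evaluation
# metric. A hypothetical classification-error feval could look like the
# function below; it assumes the predictions reach feval as class indices
# (as predict() returns for multi:softmax) -- if they arrived as raw scores
# or probabilities instead, an argmax step would be needed first.
my_error_feval <- function(preds, dtrain) {
    labels <- xgboost::getinfo(dtrain, "label")
    list(metric = "my_error", value = mean(preds != labels))
}
# it would be passed alongside params, e.g.:
# xgboost::xgb.cv(params=myxgb_params, my_iris_dmat, nrounds=100, nfold=5, feval=my_error_feval)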
# run cross validation
myxgb_cv <- xgboost::xgb.cv(params=myxgb_params, my_iris_dmat, nrounds=1000, early_stopping_rounds = 10, nfold=5)
# find best number of iterations from within the cross validation results
best_numb_iter <- which.min(myxgb_cv$evaluation_log$test_merror_mean)
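# (note, not in the original gist:) with early_stopping_rounds set, the xgb.cv
# result also records the winning round directly, so this should give the same value:
# best_numb_iter <- myxgb_cv$best_iteration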
# run the model on all of "train" with the same parameters as the cross validation
myxgb_model <- xgboost::xgboost(
    params = myxgb_params,
    data = my_iris_dmat,
    nrounds = best_numb_iter,
    print_every_n = 1
)
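# quick sanity check (illustrative addition, not in the original gist): with
# objective "multi:softmax", predict() returns class indices on the same 0..2
# scale as 'just_labels', so training accuracy can be computed directly
train_preds <- predict(myxgb_model, my_iris_dmat)
mean(train_preds == just_labels)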
# extract and plot the importance
myxgb_imp <- xgboost::xgb.importance(feature_names=colnames(my_iris_dmat), model=myxgb_model)
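# (illustrative addition:) inspect the raw importance table; 'Gain' is each
# feature's share of the total loss reduction from its splits, 'Cover' its
# share of the observations covered by those splits, and 'Frequency' how often
# it appears in splits -- Gain is usually the column to rank by
print(myxgb_imp)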
xgb.ggplot.importance(myxgb_imp)
xgb.ggplot.deepness(myxgb_model)
# see ?xgb.train for the full set of training arguments and parameters