@sebastianrothbucher
Last active December 6, 2020 11:30
Explaining nnet with LIME (and SHAP)
library(lime)
library(datasets)
library(nnet)
library(caret)
#View(iris)
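set.seed(42) # arbitrary seed so the train/test split and the nnet weight initialisation are reproducible
train_index <- sample(seq_len(nrow(iris)), 0.8 * nrow(iris))
test_index <- setdiff(seq_len(nrow(iris)), train_index)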
iris_net <- nnet(Species~., data = iris[train_index,], size = 20)
test_pred <- predict(iris_net, iris[test_index, c(1:4)], type = 'class')
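# Pin the factor levels so confusionMatrix() still works if some species is never predicted
test_conf <- confusionMatrix(factor(test_pred, levels = levels(iris$Species)), iris$Species[test_index], mode = 'prec_recall')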
#View(test_conf$byClass[,'F1'])
print(mean(test_conf$byClass[, 'F1'])) # macro-averaged F1 across the three species
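# Tell lime how to handle nnet models via the two S3 methods it dispatches on
model_type.nnet <- function(x, ...) 'classification'
predict_model.nnet <- function(x, newdata, type, ...) {
  # lime asks for class probabilities; for a multi-class nnet these come from type = 'raw'
  res <- predict(x, newdata, type = 'raw')
  # one probability column per class, named after the species, as lime expects
  as.data.frame(res)
}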
# Build the explainer from the training features only (without the Species label)
iris_expln <- lime(iris[train_index, 1:4], iris_net)
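# Explain the first four test cases: top predicted label, all four features
iris_lime <- lime::explain(iris[test_index[1:4], 1:4], iris_expln, n_labels = 1, n_features = 4)
#View(iris_lime)
print(iris$Species[test_index[1:4]]) # actual species, to compare with the explained labels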
plot_features(iris_lime)
# Also explain with SHAP (Shapley values from iml) for the third test case; compare with test_pred[3]
library(iml)
# iml draws the background samples for the Shapley estimate from `data`
iris_shappred <- Predictor$new(data = iris[test_index, 1:4], model = iris_net)
iris_shap <- Shapley$new(iris_shappred, x.interest = iris[test_index[3], 1:4])
print(iris$Species[test_index[3]])
iris_shap$plot()