Created
June 18, 2022 11:58
-
-
Save RaphaelS1/fba42d6cf05dd01489186e5c014afe8e to your computer and use it in GitHub Desktop.
shaps.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Get SHAPs from a trained xgboost model | |
#' @param model trained xgboost model | |
#' @param data new data for calculating shaps | |
get_shaps <- function(model, data) { | |
p = predict(model, data, | |
predcontrib = T, approxcontrib = F) | |
p = p[, -ncol(p)] | |
shaps = reshape2::melt(p) | |
vars = reshape2::melt(data) | |
relTruth = reshape2::melt(apply(data, 2, function(x) # nolint | |
(x - min(x, na.rm = TRUE)) / (max(x, na.rm = TRUE) - min(x, na.rm = TRUE)))) | |
shap.frame = cbind(shaps, relTruth$value, vars$value) # nolint | |
colnames(shap.frame) = c("id", "Var", "Shap", "relTruth", "Truth") | |
return(shap.frame) | |
} | |
#' Get interaction SHAPs from a trained xgboost model | |
#' @param model trained xgboost model | |
#' @param data new data for calculating shaps | |
get_interaction_shaps <- function(model, data) { | |
p = predict(model, data, | |
predinteraction = TRUE) | |
p = p[, colnames(data), colnames(data)] | |
p = array(c(as.numeric(p), as.numeric(data)), dim = c(dim(p)[1:2], dim(p)[3] + 1), | |
dimnames = list(NULL, colnames(data), c(colnames(data), "Truth"))) | |
return(p) | |
} | |
#' Plot SHAPs for a particular feature | |
#' @param shaps From `get_shaps` | |
#' @param feature Named feature to plot shap for | |
plot_shap_feature <- function(shaps, feature, | |
type = c("b", "p", "s"), | |
se = TRUE, xlab = feature, | |
smooth_color = "blue", smooth_fill = "lightblue", | |
smooth_lwd = 0.5, smooth_level = 0.95, | |
points_color = "black") { | |
Var = Truth = Shap = NULL | |
shaps = shaps[stats::complete.cases(shaps), ] | |
type = match.arg(type) | |
if (type == "b") { | |
ggplot2::ggplot(subset(shaps, Var == feature), | |
ggplot2::aes(x = Truth, y = Shap)) + | |
ggplot2::geom_point(color = points_color) + | |
ggplot2::geom_smooth(se = se, fill = smooth_fill, color = smooth_color, | |
level = smooth_level) + | |
ggplot2::labs(x = xlab, y = "SHAP Value") + | |
ggplot2::theme_minimal() + | |
ggplot2::theme( | |
panel.grid = ggplot2::element_blank(), | |
axis.line.x = ggplot2::element_line(), | |
axis.line.y = ggplot2::element_line()) | |
} else if (type == "p") { | |
ggplot2::ggplot(subset(shaps, Var == feature), | |
ggplot2::aes(x = Truth, y = Shap)) + | |
ggplot2::geom_point(color = points_color) + | |
ggplot2::labs(x = xlab, y = "SHAP Value") + theme_minimal() + | |
ggplot2::theme(panel.grid = ggplot2::element_blank(), | |
axis.line.x = ggplot2::element_line(), | |
axis.line.y = ggplot2::element_line()) | |
} else { | |
ggplot2::ggplot(subset(shaps, Var == feature), | |
ggplot2::aes(x = Truth, y = Shap)) + | |
ggplot2::geom_smooth(se = se, lwd = smooth_lwd, fill = smooth_fill, color = smooth_color, | |
level = smooth_level) + | |
ggplot2::labs(x = xlab, y = "SHAP Value") + | |
ggplot2::theme_minimal() + | |
ggplot2::theme(panel.grid = ggplot2::element_blank(), | |
axis.line.x = ggplot2::element_line(), | |
axis.line.y = ggplot2::element_line()) | |
} | |
} | |
plot_shap_summary <- function(shaps, low = "blue", high = "red", | |
jitter_w = 0, jitter_h = 0.2, | |
legend.h = 25, legend.w = 0.5, # nolint | |
show.mean = TRUE, mean.abs = TRUE, mean.round = 4, # nolint | |
show.legend = TRUE, # nolint | |
ord = NULL, violin = FALSE) { | |
Shap = Var = relTruth = x = y = NULL # nolint | |
shaps = shaps[stats::complete.cases(shaps), ] | |
if (is.null(ord)) { | |
ord = stats::aggregate(Shap ~ Var, data = shaps, function(x) mean(abs(x))) | |
shaps$Var = factor(shaps$Var, levels = as.character(ord[order(ord$Shap), 1])) | |
} else { | |
shaps$Var = factor(shaps$Var, levels = ord) | |
} | |
shaps[shaps$id == "NA", c(1, 3:5)] = NA | |
p = ggplot2::ggplot(shaps, ggplot2::aes(x = Shap, y = Var, colour = relTruth)) + | |
ggplot2::geom_vline(xintercept = 0, colour = "darkgray") + | |
ggplot2::geom_point() + ggplot2::geom_jitter(width = jitter_w, height = jitter_h) + | |
ggplot2::theme_minimal() + | |
ggplot2::theme(axis.title.y = ggplot2::element_blank(), | |
panel.grid = ggplot2::element_blank(), | |
axis.line.x = ggplot2::element_line(), | |
axis.text.y = ggplot2::element_text(size = 12), | |
legend.text = ggplot2::element_text(size = 11), | |
axis.text.x = ggplot2::element_text(size = 10)) + | |
ggplot2::labs(x = "SHAP Value") | |
if (show.legend) { | |
p = p + ggplot2::scale_color_gradient(low = low, high = high, | |
n.breaks = 2L, | |
labels = c("Low", "High"), | |
guide = ggplot2::guide_colorbar(barwidth = legend.w, barheight = legend.h, | |
title.position = "right", | |
title = "Feature Value", | |
title.theme = ggplot2::element_text(angle = 90, | |
hjust = 0.5))) | |
} else { | |
p = p + ggplot2::scale_color_gradient(low = low, high = high) + | |
ggplot2::theme(legend.position = "none") | |
} | |
if (violin) { | |
p = p + geom_violin(alpha = 0, lwd = 1) | |
} | |
if (show.mean) { | |
if (mean.abs) { | |
shapmean = stats::aggregate(Shap ~ Var, data = shaps, function(x) mean(abs(x)))[, 2] | |
} else { | |
shapmean = stats::aggregate(Shap ~ Var, data = shaps, mean)[, 2] | |
} | |
p = p + ggplot2::geom_text(aes(x = x, y = y, colour = NULL), | |
data = data.frame(x = max(shaps$Shap) + 0.5, y = seq_along(shapmean)), | |
label = format(round(shapmean, mean.round), nsmall = mean.round)) | |
} | |
p | |
} | |
plot_shap_importance <- function(shaps, relative = TRUE, bar_width = 0.4, | |
low = "blue", high = "red", order = TRUE, | |
legend.position = "none", # nolint | |
legend.w = 0.5, legend.h = 10, # nolint | |
legend_limit = NULL, xlim = NULL, | |
show.mean = TRUE, mean.round = 4) { # nolint | |
# fix visible bindings | |
Shap = Var = Mean = NULL | |
shaps = shaps[stats::complete.cases(shaps), ] | |
agg = cbind(stats::aggregate(Shap ~ Var, data = shaps, function(x) mean(abs(x))), | |
Mean = stats::aggregate(Shap ~ Var, data = shaps, mean)[, 2]) | |
if (relative) { | |
agg$Shap = (agg$Shap / sum(agg$Shap)) * 100 | |
} | |
if (order) { | |
agg = agg[order(agg$Shap, decreasing = FALSE), ] | |
agg$Var = factor(agg$Var, levels = agg$Var) | |
} | |
if (is.null(legend_limit)) { | |
legend_limit = range(agg$Mean) | |
} | |
if (is.null(xlim)) { | |
xlim = c(0, (max(agg$Shap) - (max(agg$Shap) %% 10)) + 10) | |
} | |
p = ggplot2::ggplot(agg, | |
ggplot2::aes(x = Shap, y = Var, fill = Mean)) + | |
ggplot2::geom_bar(stat = "identity", width = bar_width) + | |
ggplot2::scale_fill_gradient(low = low, high = high, n.breaks = 3, | |
limits = legend_limit, | |
guide = ggplot2::guide_colorbar(barwidth = legend.w, barheight = legend.h, | |
title.position = "right", | |
title = "Mean SHAP Value", | |
title.theme = ggplot2::element_text(angle = 90, | |
hjust = 0.5))) + | |
ggplot2::theme_minimal() + | |
ggplot2::theme(axis.title.y = ggplot2::element_blank(), | |
panel.grid = ggplot2::element_blank(), | |
axis.line.y = ggplot2::element_line(), | |
axis.text.y = ggplot2::element_text(size = 12), | |
legend.text = ggplot2::element_text(size = 11), | |
axis.text.x = ggplot2::element_text(size = 10), | |
axis.ticks = ggplot2::element_line(), | |
legend.position = legend.position) + | |
ggplot2::xlim(xlim) + | |
ggplot2::labs(x = "Feature Contribution (%)") | |
if (show.mean) { | |
p = p + ggplot2::geom_text(label = format(round(agg$Mean, mean.round), nsmall = mean.round), | |
hjust = -0.4) | |
} | |
p | |
} | |
plot_shap_sidebyside_importance <- function(shaps, bar_width = 0.4, ylim = NULL, # nolint | |
legend.position = "top") { # nolint | |
# fix visible binding | |
Shap = Var = Data = NULL | |
agg <- data.frame() | |
for (i in seq_along(shaps)) { | |
dat <- shaps[[i]][stats::complete.cases(shaps[[i]]), ] | |
dat[dat == "NA", c(1, 3:5)] <- NA | |
shapi <- cbind(Data = names(shaps)[[i]], | |
stats::aggregate(Shap ~ Var, data = dat, function(x) mean(abs(x)), | |
na.action = stats::na.pass)) | |
shapi$Shap <- (shapi$Shap / sum(shapi$Shap, na.rm = TRUE)) * 100 | |
agg <- rbind(agg, shapi) | |
} | |
if (is.null(ylim)) { | |
ylim = c(0, (max(agg$Shap, na.rm = TRUE) - (max(agg$Shap, na.rm = TRUE) %% 10)) + 10) | |
} | |
ord = order(agg[agg$Data == names(shaps)[[1]], "Shap"], decreasing = TRUE) | |
agg$Var = factor(agg$Var, levels = levels(agg$Var)[ord]) | |
ggplot2::ggplot(agg, | |
ggplot2::aes(y = Shap, x = Var, fill = Data, group = Data)) + | |
ggplot2::geom_bar(stat = "identity", width = bar_width, position = "dodge", | |
color = "black") + | |
ggplot2::geom_text(label = round(agg$Shap), position = ggplot2::position_dodge(bar_width), | |
vjust = -1) + | |
ggplot2::theme_minimal() + | |
ggplot2::theme(axis.title.x = ggplot2::element_blank(), | |
panel.grid = ggplot2::element_blank(), | |
axis.line.y = ggplot2::element_line(), | |
axis.text.x = ggplot2::element_text(size = 12), | |
legend.text = ggplot2::element_text(size = 11), | |
axis.text.y = ggplot2::element_text(size = 10), | |
axis.ticks = ggplot2::element_line(), | |
legend.title = ggplot2::element_blank(), | |
legend.position = legend.position, | |
legend.box = "horizontal") + | |
ggplot2::ylim(ylim) + | |
ggplot2::labs(y = "Feature Contribution (%)") | |
} | |
plot_shap_interaction <- function(shaps, shap_feature, intX_feature, # nolint | |
intZ_feature = NULL, type = c("b", "p", "s"), # nolint | |
se = TRUE, smooth_color = "blue", smooth_fill = "lightblue", | |
smooth_lwd = 0.5, smooth_level = 0.95, | |
low = "blue", high = "red") { | |
XTruth = SHAPS = ZTruth = NULL | |
type = match.arg(type) | |
int_shaps = shaps[, intX_feature, shap_feature] | |
Xtruth = shaps[, intX_feature, "Truth"] | |
if (!is.null(intZ_feature)) { | |
Ztruth = shaps[, intZ_feature, "Truth"] | |
} else { | |
Ztruth = Xtruth | |
} | |
data = data.frame(SHAPS = int_shaps, XTruth = Xtruth, ZTruth = Ztruth) | |
p = ggplot2::ggplot(data, ggplot2::aes(x = XTruth, y = SHAPS, color = ZTruth)) + | |
ggplot2::scale_color_gradient(low = low, high = high, n.breaks = 3, | |
guide = ggplot2::guide_colorbar(title = intZ_feature, | |
title.position = "right", | |
title.theme = ggplot2::element_text(angle = 90, | |
hjust = 0.5))) + | |
ggplot2::theme_minimal() + | |
ggplot2::theme(panel.grid = ggplot2::element_blank(), | |
axis.line.y = ggplot2::element_line(), | |
axis.line.x = ggplot2::element_line(), | |
axis.ticks = ggplot2::element_line()) + | |
ggplot2::labs(x = intX_feature, | |
y = paste("SHAP Interaction Values\n", shap_feature, "X", intX_feature)) | |
if (type == "p" | type == "b") { | |
p = p + ggplot2::geom_point() | |
} | |
if (type == "s" | type == "b") { | |
p = p + ggplot2::geom_smooth(se = se, color = smooth_color, fill = smooth_fill, | |
lwd = smooth_lwd, level = smooth_level) | |
} | |
p | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment