Skip to content

Instantly share code, notes, and snippets.

@quartox
Created April 14, 2016 16:15
Show Gist options
  • Save quartox/08d687bd391bd6942fcbf59031df0b52 to your computer and use it in GitHub Desktop.
Save quartox/08d687bd391bd6942fcbf59031df0b52 to your computer and use it in GitHub Desktop.
Using plotly to create 3D partial dependence plots of variable interactions in a gbm model.
suppressMessages(library("gbm"))
suppressMessages(library("plotly"))
# Plot the joint partial dependence of `interactingVariables` in a fitted gbm
# `model` as an interactive plotly 3D surface.
#
# model: a fitted gbm model.
# interactingVariables: the variables to cross, passed to gbm as `i.var`.
# Returns the plotly object built by .PlotInteractionSurface.
PlotInteraction <- function(model, interactingVariables) {
  # Compute the partial dependence grid first, then hand it off for rendering.
  effectGrid <- .ComputeInteractionEffect(model, interactingVariables)
  .PlotInteractionSurface(interactionEffect = effectGrid,
                          interactingVariables = interactingVariables)
}
# Turn a partial dependence grid into the x/y axis vectors and the z matrix,
# then draw them as a surface titled with the variable names.
#
# interactionEffect: grid returned by .ComputeInteractionEffect.
# interactingVariables: names used as the axis titles.
# Returns the plotly object from .PlotlySurface.
.PlotInteractionSurface <- function(interactionEffect, interactingVariables) {
  .PlotlySurface(
    xAxis = .GetXAxis(interactionEffect),
    yAxis = .GetYAxis(interactionEffect),
    marginalEffect = .GetMarginalEffect(interactionEffect),
    variableNames = interactingVariables
  )
}
# Build a plotly "surface" trace from axis vectors and a z matrix, and set the
# 3D scene axis titles from `variableNames`.
#
# xAxis, yAxis: grid coordinates along each axis.
# marginalEffect: matrix of surface heights.
# variableNames: titles for the scene axes, in x, y, z order.
# Returns the plotly figure object.
.PlotlySurface <- function(xAxis, yAxis, marginalEffect, variableNames) {
  sceneLayout <- .GetPlotlyBlueprint(variableNames)
  surfaceFigure <- plot_ly(x = xAxis, y = yAxis, z = marginalEffect,
                           type = "surface")
  layout(surfaceFigure, scene = sceneLayout)
}
# Build the plotly `scene` layout list that titles the 3D axes.
#
# variableNames: axis titles in x, y, z order. Only the first three are used,
#   because a plotly scene has exactly three axes.
# Returns a named list such as list(xaxis = list(title = ...), ...), with one
#   entry per used name. Returns an empty list when `variableNames` is empty.
.GetPlotlyBlueprint <- function(variableNames) {
  nameOfAxes <- c("xaxis", "yaxis", "zaxis")
  # Cap at three axes so extra names never index nameOfAxes with NA.
  # seq_len() avoids the 1:0 == c(1, 0) trap for empty input.
  numberOfAxes <- min(length(variableNames), length(nameOfAxes))
  axisIndex <- seq_len(numberOfAxes)
  plotlyBlueprint <- lapply(variableNames[axisIndex], .GetAxisBlueprint)
  names(plotlyBlueprint) <- nameOfAxes[axisIndex]
  plotlyBlueprint
}

# Layout settings for a single plotly scene axis; currently only its title.
.GetAxisBlueprint <- function(variableName) {
  list(
    title = variableName
  )
}
# Compute the partial dependence grid of `interactingVariables` in a fitted gbm
# `model`. It calls gbm's plot method with return.grid = TRUE so the grid is
# returned rather than drawn.
#
# model: a fitted gbm model.
# interactingVariables: passed to plot.gbm as `i.var`.
# continuousResolution: grid points per continuous variable, forwarded as
#   `continuous.resolution`. The default of 200 matches the previous
#   hard-coded value.
# Returns the grid from plot.gbm. The axis helpers in this file read column 1
#   as x, column 2 as y and column 3 as the effect.
.ComputeInteractionEffect <- function(model, interactingVariables,
                                      continuousResolution = 200) {
  gbm:::plot.gbm(model,
                 i.var = interactingVariables,
                 continuous.resolution = continuousResolution,
                 return.grid = TRUE)
}
# Shared layout assumptions for the partial dependence grid:
#   - column 1 is the x variable, varying fastest (expand.grid order);
#   - column 2 is the y variable;
#   - column 3 is the effect.
# The indexing in the three accessors below relies on this order.

# Count the distinct grid points along each axis and check that they form a
# complete x-by-y grid. Counting distinct values, rather than taking sqrt() of
# the row count, also handles non-square grids, e.g. when one variable has
# fewer grid points than the other.
# Returns c(x = <points on x>, y = <points on y>).
# Errors if the grid is incomplete, instead of silently mis-shaping the plot.
.GetGridDimensions <- function(interactionEffect) {
  numberOfXPoints <- length(unique(interactionEffect[[1]]))
  numberOfYPoints <- length(unique(interactionEffect[[2]]))
  if (numberOfXPoints * numberOfYPoints != length(interactionEffect[[3]])) {
    stop("interaction effect is not a complete x-by-y grid", call. = FALSE)
  }
  c(x = numberOfXPoints, y = numberOfYPoints)
}

# x-axis values: the first block of rows, over which column 1 takes every
# x grid value once.
.GetXAxis <- function(interactionEffect) {
  numberOfXPoints <- .GetGridDimensions(interactionEffect)[["x"]]
  interactionEffect[[1]][seq_len(numberOfXPoints)]
}

# y-axis values: column 2 at the first row of each block of x points.
.GetYAxis <- function(interactionEffect) {
  dims <- .GetGridDimensions(interactionEffect)
  yAxisIndex <- (seq_len(dims[["y"]]) - 1L) * dims[["x"]] + 1L
  interactionEffect[[2]][yAxisIndex]
}

# Effect values as a y-by-x matrix (rows = y values, columns = x values).
# This is the orientation plotly's surface trace uses for z.
.GetMarginalEffect <- function(interactionEffect) {
  dims <- .GetGridDimensions(interactionEffect)
  matrix(interactionEffect[[3]],
         nrow = dims[["y"]],
         ncol = dims[["x"]],
         byrow = TRUE)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment