Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexpghayes/7befc42976d98924d6df49a231c69088 to your computer and use it in GitHub Desktop.
Save alexpghayes/7befc42976d98924d6df49a231c69088 to your computer and use it in GitHub Desktop.
library(nnet)
library(pdp)
library(tidyverse)
library(broom)
# Augment method for nnet::multinom() fits (dispatched by broom::augment()).
#
# Scores `newdata` with the fitted model and appends one predicted-probability
# column per outcome level.
#
# object  - a fitted `multinom` model.
# newdata - a data frame (or anything as_tibble() accepts) of predictors.
# ...     - unused; accepted so the method matches the augment() generic.
#
# Returns a tibble: the columns of `newdata` followed by one probability
# column per outcome level.
augment.multinom <- function(object, newdata, ...) {
  newdata <- as_tibble(newdata)
  # Spell out "probs": "prob" only worked through match.arg() partial matching.
  class_probs <- predict(object, newdata, type = "probs")

  # predict.multinom() drop()s its n x K result. A two-level outcome yields
  # only P(second level), and a single row of newdata yields a named vector.
  # Rebuild the full probability matrix in both cases so as_tibble() gets a
  # matrix with one named column per level.
  if (length(object$lev) == 2L) {
    p_second <- as.numeric(class_probs)
    class_probs <- matrix(
      c(1 - p_second, p_second),
      ncol = 2L,
      dimnames = list(NULL, object$lev)
    )
  } else if (is.null(dim(class_probs))) {
    class_probs <- matrix(
      class_probs,
      nrow = 1L,
      dimnames = list(NULL, names(class_probs))
    )
  }

  bind_cols(newdata, as_tibble(class_probs))
}
# Fit a multinomial logistic regression of Species on all other columns.
# multinom() is given a factor outcome here, so fail fast if Species is not
# one (the original check only printed TRUE/FALSE and enforced nothing).
# `data` and `fit` are read as globals by the partial-dependence code below.
data <- iris
stopifnot(is.factor(data$Species))
fit <- multinom(Species ~ ., data, trace = FALSE)
fit
# Estimate the partial dependence of each class probability on one feature.
#
# For every distinct observed value v of `predictor`, the feature is set to v
# in every row of `training_data` while all other columns keep their observed
# values. The model's class probabilities are then averaged over those rows.
#
# predictor     - the feature, as a bare column name or a string.
# model         - fitted multinom model; defaults to the global `fit`.
# training_data - data to average over; defaults to the global `data`.
#
# Returns an ungrouped tibble with columns `class`, the predictor, and
# `marginal_prob`: one row per class / feature-value pair.
partial_dependence <- function(predictor, model = fit, training_data = data) {
  var <- ensym(predictor)
  # Grid values: each distinct observed value of the feature of interest.
  x_s <- distinct(select(training_data, !!var))
  # Keep every row of the remaining features, duplicates included, so the
  # average runs over their empirical distribution. crossing() would
  # de-duplicate these rows and bias the estimate.
  x_c <- select(training_data, -!!var)
  grid <- expand_grid(x_s, x_c)

  augment(model, grid) %>%
    # One probability column per outcome level, named after the levels.
    pivot_longer(
      all_of(model$lev),
      names_to = "class",
      values_to = "prob"
    ) %>%
    group_by(class, !!var) %>%
    summarize(marginal_prob = mean(prob), .groups = "drop")
}
# Partial dependence curves for every predictor (all columns of `data` except
# the Species outcome), stacked into long format: one row per
# class x feature x feature value.
all_dependencies <- setdiff(colnames(data), "Species") %>%
  map_dfr(partial_dependence) %>%
  # Each feature's rows carry NA in the other features' columns; dropping NA
  # values while reshaping keeps only the observed feature/value pairs.
  pivot_longer(
    -c(class, marginal_prob),
    names_to = "feature",
    values_to = "feature_value",
    values_drop_na = TRUE
  )
all_dependencies

all_dependencies %>%
  ggplot(aes(feature_value, marginal_prob, color = class)) +
  # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0).
  geom_line(linewidth = 1) +
  facet_wrap(vars(feature), scales = "free_x") +
  scale_color_viridis_d() +
  labs(
    title = "Partial dependence plots for all features",
    y = "Marginal probability of class",
    x = "Value of feature"
  ) +
  theme_minimal()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment