Skip to content

Instantly share code, notes, and snippets.

@ryanholbrook
Created January 18, 2020 13:45
Show Gist options
  • Save ryanholbrook/220f39a669a870f4d312e46b67302106 to your computer and use it in GitHub Desktop.
Save ryanholbrook/220f39a669a870f4d312e46b67302106 to your computer and use it in GitHub Desktop.
R code for plotting and animating the decision boundaries

Classifiers

Introduction

Looking at the decision boundary a classifier generates can give us some geometric intuition about the decision rule a classifier uses and how this decision rule changes as the classifier is trained on more data.

Plotting Functions

##' Plot a fitted decision boundary against the optimal one
##'
##' Both boundaries are drawn as the zero-level contour of a column of
##' `density`: `optimal` (dashed) and `fitted` (solid).
##'
##' @param density `data.frame`: the feature grid; should have columns
##'     `x`, `y`, `optimal`, and `fitted`
##' @param points `data.frame`: the sample to overlay, passed to
##'     `gg_sample`
##' @param title `character`: the plot title; an empty string (the
##'     default) leaves the plot untitled
gg_plot_boundary <- function(density, points, title = "") {
    ggplot() +
    ## gg_sample(data = density, size = 1.5, alpha = 0.1, shape = 15) +
    gg_sample(data = points) +
    gg_density(data = density, z = optimal, breaks = c(0), linetype = 2) +
    gg_density(data = density, z = fitted, breaks = c(0)) +
    coord_fixed(expand = FALSE) +
    ## the x limits must come from the x column (the upper bound used
    ## to be taken from `y`)
    xlim(min(density$x), max(density$x)) +
    ylim(min(density$y), max(density$y)) +
    ## `title` used to be accepted but silently ignored
    labs(title = if (nzchar(title)) title else NULL)
}

Animation

##' Animate the evolution of a decision boundary as the sample size grows
##'
##' @param sample `data.frame`: the complete sample data; should have
##'     columns `x`, `y`, and `class`
##' @param density `data.frame`: the feature grid over `x` and `y`;
##'     should have columns `x`, `y`, and `optimal`, whose zero-level
##'     contour is drawn as the optimal decision boundary
##' @param delta `integer`: how many points to add at each step of the
##'     animation
##' @param fit_and_predict `function(sample, density)`: fits a learner
##'     to the sample data and returns `density` with an added `fitted`
##'     column; the fitted boundary is drawn as its 0.5-level contour
##' @return the animation rendered by `animate()` with the gifski
##'     renderer at 800 x 800
animate_boundary <- function(sample, density, delta, fit_and_predict) {
    ## a data.frame with successive groups of `data` of size `delta`;
    ## `data` is randomized before sequencing
    sequence_data <- function(data, delta) {
        rows <- nrow(data)
        ## round up so that a final, partial batch still gets a frame
        ## and the last frame always shows the complete sample
        n <- ceiling(rows / delta)
        ## `base::` makes explicit that this is the sampling function,
        ## not the enclosing `sample` data argument
        data <- data[base::sample(rows), ]
        go <- function(i) {
            h <- min(i * delta, rows)
            bind_cols(
                head(data, h),
                group = rep.int(i, h))
        }
        bind_rows(lapply(seq_len(n), go))
    }
    ## Sequence the sample data
    sample_sequenced <- sequence_data(sample, delta)
    ## Sequence the density data and attach predictions from the sample
    density_sequenced <- sample_sequenced %>%
        group_by(group) %>%
        group_modify(~ fit_and_predict(.x, density)) %>%
        ungroup()
    ## Define the animation
    anim <- ggplot() +
        ## Plot the sample
        geom_point(data = sample_sequenced,
                   aes(x = x, y = y,
                       color = factor(class),
                       shape = factor(class)),
                   size = 3,
                   alpha = 0.5) +
        ## Plot the optimal decision boundary
        geom_contour(data = density_sequenced,
                     aes(x, y, z = optimal),
                     breaks = c(0),
                     color = "black",
                     size = 1,
                     linetype = 2) +
        ## Plot the fitted decision boundary
        geom_contour(data = density_sequenced,
                     aes(x, y, z = fitted),
                     breaks = c(0.5),
                     color = "black",
                     size = 1) +
        coord_fixed(expand = FALSE) +
        ## the x limits must come from the x column (the upper bound
        ## used to be taken from `y`)
        xlim(min(density$x), max(density$x)) +
        ylim(min(density$y), max(density$y)) +
        theme_linedraw() +
        theme(plot.title = element_text(hjust = 0.5, size = 20),
              legend.position = "none",
              axis.text.x = element_blank(),
              axis.text.y = element_blank(),
              axis.title.x = element_blank(),
              axis.title.y = element_blank()) +
        ## Animate the sample and the fitted boundary
        transition_manual(group)
    animate(anim, renderer = gifski_renderer(),
            width = 800, height = 800)
}

Lattice plots

##' Lattice version of the decision-boundary plot
##'
##' Draws the feature grid coloured by fitted class, then overlays the
##' optimal boundary (dashed), the fitted boundary (solid) and the
##' sample. Both boundaries are zero-level contours.
##'
##' @param density `data.frame`: the feature grid; should have columns
##'     `x`, `y`, `optimal`, and `fitted`
##' @param sample `data.frame`: the sample; should have columns `x`,
##'     `y`, and `class`
##' @param title the main title of the plot
lattice_plot_boundary <- function(density, sample, title) {
    ## `[[` returns a plain vector for data.frames and tibbles alike
    fitted_class <- ifelse(density[["fitted"]] > 0, 1, 0)
    ## classes in the feature grid; this is the base plot, so the title
    ## goes here -- annotations such as `main` on the overlaid plots
    ## below are not drawn once they are added as layers
    lattice::xyplot(y ~ x, groups = fitted_class,
                    data = density,
                    cex = 1, pch = 20, alpha = 0.1,
                    aspect = 1,
                    main = title) +
    ## the optimal decision boundary
    lattice::contourplot(optimal ~ x + y,
                         data = density,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3,
                         lty = 2,
                         aspect = 1) +
    ## the fitted decision boundary
    lattice::contourplot(fitted ~ x + y,
                         data = density,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3) +
    ## the sample
    lattice::xyplot(y ~ x, groups = class,
                    data = sample,
                    pch = 19, alpha = 0.5)
}


## Exploratory lattice plots of the grid in `density_mvn`.
## NOTE(review): combining trellis objects with `+` is presumably
## provided by latticeExtra being attached -- confirm.

## Contour lines (no filled regions) of `p_0` with `p_1` overlaid,
## 10 cuts each.
lattice::levelplot(p_0 ~ x + y,
                   contour = TRUE,
                   region = FALSE,
                   cuts = 10,
                   data = density_mvn,
                   aspect = 1) +
    lattice::levelplot(p_1 ~ x + y,
                       contour = TRUE,
                       region = FALSE,
                       cuts = 10,
                       data = density_mvn)

## Level plot of `optimal` with its zero-level contour overlaid.
lattice::levelplot(optimal ~ x + y,
                   data = density_mvn,
                   aspect = 1,
                   cuts = 20,
                   contour = TRUE) +
    lattice::contourplot(optimal ~ x + y,
                         data = density_mvn,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3)

## Coarser (5 cuts) contours of `p_0` and `p_1` together with the
## zero-level contour of `optimal`.
lattice::levelplot(p_0 ~ x + y,
                   contour = TRUE,
                   region = FALSE,
                   cuts = 5,
                   data = density_mvn,
                   aspect = 1) +
    lattice::levelplot(p_1 ~ x + y,
                       contour = TRUE,
                       region = FALSE,
                       cuts = 5,
                       data = density_mvn) +
    lattice::contourplot(optimal ~ x + y,
                         data = density_mvn,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3)
    
## Zero-level contour of `optimal` with a scatter of points grouped by
## `class`.
## NOTE(review): `points` is not defined anywhere in the visible
## source; presumably a sample data frame left over from an interactive
## session -- confirm before running.
lattice::contourplot(optimal ~ x + y,
                     data = density_mvn,
                     at = c(0),
                     labels = FALSE,
                     lwd = 3,
                     aspect = 1) +
    lattice::xyplot(y ~ x, groups = class, data = points)

Regression Models

Linear

## Linear regression of `class` on `x` and `y`, fitted to `sample_mvn`
## and evaluated on the grid `density_mvn`. Subtracting 0.5 moves the
## zero level of `fitted` to a predicted value of 0.5, which is the
## level `gg_plot_boundary` contours.
fit_lm <- lm(class ~ x + y, data = sample_mvn)
pred_lm <- predict(fit_lm, newdata = density_mvn)
density_lm <- cbind(density_mvn, "fitted" = pred_lm - 0.5)

gg_plot_boundary(density_lm, sample_mvn, "Linear")

## Disabled: confusion-matrix plot for the linear model.
## class_lm <- ifelse(pred_lm > 0, 1, 0)

## confusion_lm <- table(density[, "class"],
##                    density[, "class_lm"],
##                    dnn = c("True", "Predicted"))

## fourfoldplot(confusion_lm, main = "Linear")

Logistic

## Logistic regression on each data set. Predictions are requested on
## the response scale and shifted by 0.5 so that the zero level of
## `fitted` is the 0.5 level of the prediction.

## `sample_mvn` / `density_mvn`
fit_mvn_glm <- glm(class ~ x + y, data = sample_mvn, family = binomial)
pred_mvn_glm <- predict(fit_mvn_glm, newdata = density_mvn, type = "response")
density_mvn_glm <- cbind(density_mvn, "fitted" = pred_mvn_glm - 0.5)
gg_plot_boundary(density_mvn_glm, sample_mvn, "Logistic")

## `sample_mix` / `density_mix`
fit_mix_glm <- glm(class ~ x + y, data = sample_mix, family = binomial)
pred_mix_glm <- predict(fit_mix_glm, newdata = density_mix, type = "response")
density_mix_glm <- cbind(density_mix, "fitted" = pred_mix_glm - 0.5)
gg_plot_boundary(density_mix_glm, sample_mix, "Logistic")


##' Fit a logistic regression and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on; should have
##'     columns `x` and `y`
##' @return `density` with an added `fitted` column holding the
##'     predictions on the response scale (not shifted by 0.5)
fit_and_predict_glm <- function(sample, density) {
    fit_glm <- glm(class ~ x + y, data = sample, family = binomial)
    ## predict on the grid that was passed in; this used to read the
    ## global `density_mvn`, which gave wrong predictions for any other
    ## grid
    pred_glm <- predict(fit_glm, newdata = density, type = "response")
    cbind(density, fitted = pred_glm)
}

## Animate the logistic-regression boundary on each data set, adding
## 10 points per step, and save each animation as a GIF.
anim_mvn_glm <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_glm)
anim_save("/home/jovyan/work/bayeserror/glm_mvn.gif", animation = anim_mvn_glm)

anim_mix_glm <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_glm)
anim_save("/home/jovyan/work/bayeserror/glm_mix.gif", animation = anim_mix_glm)

Logistic GAM

## Logistic GAM with a bivariate smooth of `x` and `y` on each data
## set. The model family is selected with `family = binomial`; the
## previous `class = "bernoulli"` is not a `gam` argument, so the fit
## was not a logistic one. Predictions are on the response scale and
## shifted by 0.5 so that the zero level of `fitted` is the 0.5 level
## of the prediction.

## `sample_mvn` / `density_mvn`
fit_mvn_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample_mvn)
pred_mvn_gam <- predict(fit_mvn_gam, newdata = density_mvn, type = "response")
density_mvn_gam <- cbind(density_mvn, "fitted" = as.numeric(pred_mvn_gam) - 0.5)
gg_plot_boundary(density_mvn_gam, sample_mvn, title = "GAM")

## `sample_mix` / `density_mix`
fit_mix_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample_mix)
pred_mix_gam <- predict(fit_mix_gam, newdata = density_mix, type = "response")
density_mix_gam <- cbind(density_mix, "fitted" = as.numeric(pred_mix_gam) - 0.5)
gg_plot_boundary(density_mix_gam, sample_mix, title = "GAM")


##' Fit a logistic GAM and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on; should have
##'     columns `x` and `y`
##' @return `density` with an added numeric `fitted` column holding the
##'     predictions on the response scale (not shifted by 0.5)
fit_and_predict_gam <- function(sample, density) {
    ## `family = binomial` selects the logistic model; the previous
    ## `class = "bernoulli"` is not a `gam` argument
    fit_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample)
    pred_gam <- predict(fit_gam, newdata = density, type = "response")
    ## `as.numeric` drops the array structure of the prediction
    cbind(density, "fitted" = as.numeric(pred_gam))
}

## Animate the GAM boundary on each data set, adding 100 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_gam <- animate_boundary(sample_mvn, density_mvn, 100, fit_and_predict_gam)
anim_save("/home/jovyan/work/bayeserror/gam_mvn.gif", animation = anim_mvn_gam)
anim_mix_gam <- animate_boundary(sample_mix, density_mix, 100, fit_and_predict_gam)
anim_save("/home/jovyan/work/bayeserror/gam_mix.gif", animation = anim_mix_gam)

Splines and Smoothers

MARS

## MARS (`earth`) with a binomial GLM on each data set. Predictions are
## requested on the response scale, flattened with `as.numeric` and
## shifted by 0.5 so that the zero level of `fitted` is the 0.5 level
## of the prediction.

## `sample_mvn` / `density_mvn`
fit_mvn_mars <- earth::earth(factor(class) ~ x + y,
                         data = sample_mvn,
                         glm = list(family = "binomial"))
pred_mvn_mars <- predict(fit_mvn_mars, newdata = density_mvn, type = "response")
density_mvn_mars <- cbind(density_mvn, "fitted" = as.numeric(pred_mvn_mars) - 0.5)
gg_plot_boundary(density_mvn_mars, sample_mvn, title = "MARS")


## `sample_mix` / `density_mix`
fit_mars <- earth::earth(factor(class) ~ x + y,
                         data = sample_mix,
                         glm = list(family = "binomial"))
pred_mars <- predict(fit_mars, newdata = density_mix, type = "response")
density_mars <- cbind(density_mix, "fitted" = as.numeric(pred_mars) - 0.5)
gg_plot_boundary(density_mars, sample_mix, title = "MARS")


##' Fit a MARS model with a binomial GLM and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @return `density` with an added numeric `fitted` column holding the
##'     predictions on the response scale (not shifted by 0.5)
fit_and_predict_mars <- function(sample, density) {
    mars_model <- earth::earth(factor(class) ~ x + y,
                               data = sample,
                               glm = list(family = "binomial"))
    ## flatten the prediction to a plain numeric vector
    response <- as.numeric(
        predict(mars_model, newdata = density, type = "response"))
    cbind(density, "fitted" = response)
}

## Animate the MARS boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_mars <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_mars)
anim_save("/home/jovyan/work/bayeserror/mars_mvn.gif", animation = anim_mvn_mars)
anim_mix_mars <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_mars)
anim_save("/home/jovyan/work/bayeserror/mars_mix.gif", animation = anim_mix_mars)

Poly-MARS

## PolyMARS classifier (`classify = TRUE`) on each data set. The second
## column of the prediction is used as the fitted value and shifted by
## 0.5. NOTE(review): column 2 is presumably the fitted value for the
## second class -- confirm.

## `sample_mvn` / `density_mvn`
fit_pmars <- polspline::polymars(sample_mvn[["class"]],
                                 sample_mvn[, c("x", "y")],
                                 classify = TRUE)
pred_pmars <- predict(fit_pmars,
                      x = as.data.frame(density_mvn[, c("x", "y")]))
density_pmars <- cbind(density_mvn, "fitted" = pred_pmars[, 2] - 0.5)
gg_plot_boundary(density_pmars, sample_mvn, title = "PolyMARS")


## `sample_mix` / `density_mix`; reuses (overwrites) the variable names
## from above
fit_pmars <- polspline::polymars(sample_mix[["class"]],
                                 sample_mix[, c("x", "y")],
                                 classify = TRUE)
pred_pmars <- predict(fit_pmars,
                      x = as.data.frame(density_mix[, c("x", "y")]))
density_pmars <- cbind(density_mix, "fitted" = pred_pmars[, 2] - 0.5)
gg_plot_boundary(density_pmars, sample_mix, title = "PolyMARS")



##' Fit a PolyMARS classifier and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on; should have
##'     columns `x` and `y`
##' @return `density` with an added `fitted` column taken from the
##'     second column of the prediction (not shifted by 0.5)
fit_and_predict_polymars <- function(sample, density) {
    features <- c("x", "y")
    ## polymars takes the responses and the predictors separately
    pmars_model <- polspline::polymars(sample[["class"]],
                                       as.data.frame(sample[, features]),
                                       classify = TRUE)
    grid_pred <- predict(pmars_model,
                         x = as.data.frame(density[, features]))
    cbind(density, "fitted" = grid_pred[, 2])
}

## Animate the PolyMARS boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_pmars <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_polymars)

anim_save("/home/jovyan/work/bayeserror/pmars_mvn.gif", animation = anim_mvn_pmars)

anim_mix_pmars <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_polymars)

anim_save("/home/jovyan/work/bayeserror/pmars_mix.gif", animation = anim_mix_pmars)

Discriminant Analysis

Linear Discriminant Analysis

## LDA fitted to, and evaluated on, the grid `density_mvn` itself; the
## posterior for class "1" is shifted by 0.5.
## NOTE(review): this trains on the grid rather than on `sample_mvn`,
## which requires `density_mvn` to carry a `class` column, and
## `density_lda` is not plotted in the visible source -- confirm
## whether that is intended.
fit_lda <- MASS::lda(class ~ x + y, data = density_mvn)
pred_lda <- predict(fit_lda, newdata = density_mvn)
density_lda <- cbind(density_mvn, "fitted" = pred_lda$posterior[, "1"] - 0.5)

Quadratic Discriminant Analysis

## QDA evaluated on the grid `density_mvn`, fitted two ways. In both
## cases the posterior for class "1" is shifted by 0.5 so that the zero
## level of `fitted` is a posterior of 0.5.

## Fitted to the grid itself (requires a `class` column in
## `density_mvn`)
fit_qda <- MASS::qda(class ~ x + y, data = density_mvn)
pred_qda <- predict(fit_qda, newdata = density_mvn)
density_qda <- cbind(density_mvn, "fitted" = pred_qda$posterior[, "1"] - 0.5)
gg_plot_boundary(density_qda, sample_mvn, title = "QDA")

## Fitted to the sample points
fit_qda_points <- MASS::qda(class ~ x + y, data = sample_mvn)
pred_qda_points <- predict(fit_qda_points, newdata = density_mvn)
density_qda_points <- cbind(density_mvn, "fitted" = pred_qda_points$posterior[, "1"] - 0.5)
gg_plot_boundary(density_qda_points, sample_mvn, title = "QDA")


##' Fit a quadratic discriminant analysis and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @return `density` with an added `fitted` column holding the
##'     posterior of the class labelled "1" (not shifted by 0.5)
fit_and_predict_qda <- function(sample, density) {
    qda_model <- MASS::qda(class ~ x + y, data = sample)
    posterior <- predict(qda_model, newdata = density)$posterior
    cbind(density, "fitted" = posterior[, "1"])
}

## Animate the QDA boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_qda <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_qda)

anim_save("/home/jovyan/work/bayeserror/qda_mvn.gif", animation = anim_mvn_qda)

anim_mix_qda <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_qda)

anim_save("/home/jovyan/work/bayeserror/qda_mix.gif", animation = anim_mix_qda)

Mixture Discriminant Analysis

Nearest Neighbors

## 5-nearest-neighbour classification of the grid from each sample.
## `knn` returns a factor; its integer codes are 1 and 2 for a
## two-class problem (assumed here), so subtracting 1.5 gives -0.5 and
## 0.5 and puts the zero level that `gg_plot_boundary` contours midway
## between the two classes. (The previous offset of 2 gave -1 and 0,
## placing the contour level on a class value instead of between them.)

## `sample_mvn` / `density_mvn`
pred_nn <- class::knn(train = sample_mvn[, c("x", "y")],
                      cl = factor(sample_mvn[, "class"]),
                      test = density_mvn[, c("x", "y")],
                      k = 5)
density_nn <- cbind(density_mvn, "fitted" = as.integer(pred_nn) - 1.5)
gg_plot_boundary(density_nn, sample_mvn, title = "Nearest Neighbors")


## `sample_mix` / `density_mix`; the plot overlays the mixture sample
## to match the mixture grid (it used to overlay `sample_mvn`)
pred_nn <- class::knn(train = sample_mix[, c("x", "y")],
                      cl = factor(sample_mix[, "class"]),
                      test = density_mix[, c("x", "y")],
                      k = 5)
density_nn <- cbind(density_mix, "fitted" = as.integer(pred_nn) - 1.5)
gg_plot_boundary(density_nn, sample_mix, title = "Nearest Neighbors")
##' Classify a feature grid with 5-nearest-neighbours
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to classify; should have
##'     columns `x` and `y`
##' @return `density` with an added `fitted` column that is 0 for the
##'     first class level and 1 for the second, so that the 0.5-level
##'     contour drawn by `animate_boundary` falls between the classes
fit_and_predict_knn <- function(sample, density) {
    pred_knn <- class::knn(train = sample[, c("x", "y")],
                           cl = factor(sample$class),
                           test = density[, c("x", "y")],
                           k = 5)
    ## factor codes are 1-based; shift them to 0/1, the same scale the
    ## other learners return (the previous offset of 1.5 gave -0.5/0.5,
    ## putting the 0.5 contour level on a class value)
    cbind(density,
          fitted = as.integer(pred_knn) - 1)
}

## Animate the nearest-neighbour boundary on each data set, adding 10
## points per step, and save each animation as a GIF.
anim_mvn_knn <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_knn)
anim_save("/home/jovyan/work/bayeserror/knn_mvn.gif", animation = anim_mvn_knn)

anim_mix_knn <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_knn)
anim_save("/home/jovyan/work/bayeserror/knn_mix.gif", animation = anim_mix_knn)

Kernel NN

## Kernel-weighted nearest neighbours with a Gaussian kernel, trained
## on `sample_mvn`. The second column of the predicted class
## probabilities is shifted by 0.5 and used as the fitted value.
## NOTE(review): column 2 is presumably the probability of the second
## class level -- confirm.
fit_kknn <- kknn::train.kknn(factor(class) ~ x + y,
                             data = sample_mvn,
                             kernel = "gaussian")
pred_kknn <- predict(fit_kknn, newdata = density_mvn, type = "prob")
density_kknn <- cbind(density_mvn, "fitted" = pred_kknn[, 2] - 0.5)
gg_plot_boundary(density_kknn, sample_mvn, title = "KKNN")

Support Vector Machines

## Support vector machine with an RBF kernel on each data set.
## `prob.model = TRUE` is set at fit time so that class probabilities
## can be requested at prediction time; the probability of class "1"
## is shifted by 0.5 so that the zero level of `fitted` is a
## probability of 0.5.

## `sample_mvn` / `density_mvn`
fit_svm_points <- kernlab::ksvm(factor(class) ~ x + y,
                                data = sample_mvn,
                                kernel = "rbfdot",
                                prob.model = TRUE)
pred_svm_points <- kernlab::predict(fit_svm_points,
                                    newdata = density_mvn,
                                    type = "probabilities")
density_svm_points <- cbind(density_mvn, "fitted" = pred_svm_points[, "1"] - 0.5)
gg_plot_boundary(density_svm_points, sample_mvn, title = "SVM")



## `sample_mix` / `density_mix`
fit_svm <- kernlab::ksvm(factor(class) ~ x + y,
                         data = sample_mix,
                         kernel = "rbfdot",
                         prob.model = TRUE)
pred_svm <- kernlab::predict(fit_svm,
                             newdata = density_mix,
                             type = "probabilities")
density_svm <- cbind(density_mix, "fitted" = pred_svm[, "1"] - 0.5)
gg_plot_boundary(density_svm, sample_mix, title = "SVM")

##' Fit an RBF-kernel support vector machine and predict on a feature
##' grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @return `density` with an added `fitted` column holding the
##'     predicted probability of the class labelled "1" (not shifted by
##'     0.5)
fit_and_predict_svm <- function(sample, density) {
    ## the probability model has to be requested when fitting
    svm_model <- kernlab::ksvm(factor(class) ~ x + y,
                               data = sample,
                               kernel = "rbfdot",
                               prob.model = TRUE)
    class_probs <- kernlab::predict(svm_model,
                                    newdata = density,
                                    type = "probabilities")
    cbind(density, "fitted" = class_probs[, "1"])
}

## Animate the SVM boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_svm <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_svm)
anim_save("/home/jovyan/work/bayeserror/svm_mvn.gif", animation = anim_mvn_svm)
anim_mix_svm <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_svm)
anim_save("/home/jovyan/work/bayeserror/svm_mix.gif", animation = anim_mix_svm)

Trees

Decision Trees

## Classification tree (`method = "class"`) on each data set. The
## prediction column for class "1" is shifted by 0.5 so that the zero
## level of `fitted` is the 0.5 level of the prediction.

## `sample_mvn` / `density_mvn`
fit_rpart_points <- rpart::rpart(class ~ x + y, data = sample_mvn, method = "class")
pred_rpart_points <- predict(fit_rpart_points, newdata = density_mvn)
density_rpart_points <- cbind(density_mvn, "fitted" = pred_rpart_points[, "1"] - 0.5)
gg_plot_boundary(density_rpart_points, sample_mvn, title = "Decision Tree")

## `sample_mix` / `density_mix`
fit_rpart <- rpart::rpart(class ~ x + y, data = sample_mix, method = "class")
pred_rpart <- predict(fit_rpart, newdata = density_mix)
density_rpart <- cbind(density_mix, "fitted" = pred_rpart[, "1"] - 0.5)
gg_plot_boundary(density_rpart, sample_mix, title = "Decision Tree")

##' Fit a classification tree and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @return `density` with an added `fitted` column taken from the
##'     prediction column labelled "1" (not shifted by 0.5)
fit_and_predict_rpart <- function(sample, density) {
    tree <- rpart::rpart(class ~ x + y, data = sample, method = "class")
    class_probs <- predict(tree, newdata = density)
    cbind(density, "fitted" = class_probs[, "1"])
}

## Animate the decision-tree boundary on each data set, adding 10
## points per step. Each animation is handed to `anim_save` explicitly
## instead of relying on its "last rendered animation" default.
anim_mvn_rpart <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_rpart)
anim_save("/home/jovyan/work/bayeserror/rpart_mvn.gif", animation = anim_mvn_rpart)
anim_mix_rpart <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_rpart)
anim_save("/home/jovyan/work/bayeserror/rpart_mix.gif", animation = anim_mix_rpart)

Bagged Trees

Random Forests

## Probability random forest (`probability = TRUE`) on each data set.
## The predicted probability of class "1" is shifted by 0.5 so that the
## zero level of `fitted` is a probability of 0.5. Note that `ranger`'s
## predict method takes the new data as `data`.

## `sample_mvn` / `density_mvn`
fit_rf <- ranger::ranger(factor(class) ~ x + y,
                         data = sample_mvn,
                         probability = TRUE)
pred_rf <- predict(fit_rf, data = density_mvn)
density_rf <- cbind(density_mvn, "fitted" = pred_rf$predictions[, "1"] - 0.5)
gg_plot_boundary(density_rf, sample_mvn, title = "Random Forest")


## `sample_mix` / `density_mix`; reuses (overwrites) the variable names
## from above
fit_rf <- ranger::ranger(factor(class) ~ x + y,
                         data = sample_mix,
                         probability = TRUE)
pred_rf <- predict(fit_rf, data = density_mix)
density_rf <- cbind(density_mix, "fitted" = pred_rf$predictions[, "1"] - 0.5)
gg_plot_boundary(density_rf, sample_mix, title = "Random Forest")



##' Fit a probability random forest and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @return `density` with an added `fitted` column holding the
##'     predicted probability of the class labelled "1" (not shifted by
##'     0.5)
fit_and_predict_rf <- function(sample, density) {
    forest <- ranger::ranger(factor(class) ~ x + y,
                             data = sample,
                             probability = TRUE)
    ## ranger's predict method takes the new data as `data`
    class_probs <- predict(forest, data = density)$predictions
    cbind(density, "fitted" = class_probs[, "1"])
}

## Animate the random-forest boundary on each data set, adding 10
## points per step, and save each animation as a GIF. The MVN file name
## follows the `<model>_<data>.gif` pattern of the other outputs (it
## used to contain a stray dot: `rf._mvn.gif`).
anim_mvn_rf <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_rf)
anim_save("/home/jovyan/work/bayeserror/rf_mvn.gif", animation = anim_mvn_rf)

anim_mix_rf <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_rf)
anim_save("/home/jovyan/work/bayeserror/rf_mix.gif", animation = anim_mix_rf)

BART

Gradient Boosting

## Gradient-boosted trees with a Bernoulli loss on each data set. The
## same number of trees is used for fitting and for prediction;
## predictions are on the response scale and shifted by 0.5 so that the
## zero level of `fitted` is the 0.5 level of the prediction.

## `sample_mvn` / `density_mvn`, 100 trees
fit_gbm <- gbm::gbm(class ~ x + y,
                    data = sample_mvn,
                    n.trees = 100,
                    distribution = "bernoulli")
pred_gbm <- predict(fit_gbm,
                    n.trees = 100,
                    newdata = density_mvn,
                    type = "response")
density_gbm <- cbind(density_mvn, "fitted" = pred_gbm - 0.5)
gg_plot_boundary(density_gbm, sample_mvn, title = "Boosted Trees")


## `sample_mix` / `density_mix`, 500 trees; reuses (overwrites) the
## variable names from above
fit_gbm <- gbm::gbm(class ~ x + y,
                    data = sample_mix,
                    n.trees = 500,
                    distribution = "bernoulli")
pred_gbm <- predict(fit_gbm,
                    n.trees = 500,
                    newdata = density_mix,
                    type = "response")
density_gbm <- cbind(density_mix, "fitted" = pred_gbm-0.5)
gg_plot_boundary(density_gbm, sample_mix, title = "Boosted Trees")

xgboost

## xgboost with a logistic objective on each data set. The number of
## boosting rounds is chosen by 5-fold cross-validation (at most 50
## rounds, early stopping after 3 rounds), and the final model is then
## refitted on the whole sample with `xgcv$best_iteration` rounds.
## Predictions are shifted by 0.5 so that the zero level of `fitted` is
## the 0.5 level of the prediction.
## NOTE(review): `as.numeric(...$class)` presumably yields 0/1 labels,
## i.e. `class` is stored as a number rather than a factor -- confirm.

## `sample_mvn` / `density_mvn`
set.seed(31415)
sample_xg <- xgboost::xgb.DMatrix(
                          as.matrix(sample_mvn[, c("x", "y")]),
                          label = as.numeric(sample_mvn$class))
xgcv <- xgboost::xgb.cv(data = sample_xg,
                        nrounds = 50,
                        early_stopping_rounds = 3,
                        nfold = 5,
                        objective = "binary:logistic")
fit_xg <- xgboost::xgboost(data = sample_xg,
                           nrounds = xgcv$best_iteration,
                           objective = "binary:logistic")
pred_xg <- predict(fit_xg, newdata = as.matrix(density_mvn[, c("x", "y")]))
density_xg <- cbind(density_mvn, "fitted" = pred_xg - 0.5)
gg_plot_boundary(density_xg, sample_mvn, title = "xgboost")


## `sample_mix` / `density_mix`; the seed is reset and the variable
## names from above are reused (overwritten)
set.seed(31415)
sample_xg <- xgboost::xgb.DMatrix(
                          as.matrix(sample_mix[, c("x", "y")]),
                          label = as.numeric(sample_mix$class))
xgcv <- xgboost::xgb.cv(data = sample_xg,
                        nrounds = 50,
                        early_stopping_rounds = 3,
                        nfold = 5,
                        objective = "binary:logistic")
fit_xg <- xgboost::xgboost(data = sample_xg,
                           nrounds = xgcv$best_iteration,
                           objective = "binary:logistic")
pred_xg <- predict(fit_xg, newdata = as.matrix(density_mix[, c("x", "y")]))
density_xg <- cbind(density_mix, "fitted" = pred_xg - 0.5)
gg_plot_boundary(density_xg, sample_mix, title = "xgboost")


##' Fit an xgboost classifier and predict on a feature grid
##'
##' The number of boosting rounds is chosen by 5-fold cross-validation
##' (at most 50 rounds, early stopping after 3 rounds); the model is
##' then refitted on the whole sample with that many rounds. The seed
##' is fixed on every call.
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on; should have
##'     columns `x` and `y`
##' @return `density` with an added `fitted` column holding the
##'     predictions (not shifted by 0.5)
fit_and_predict_xgboost <- function(sample, density) {
    set.seed(31415)
    features <- c("x", "y")
    train_xg <- xgboost::xgb.DMatrix(
                             as.matrix(sample[, features]),
                             label = as.numeric(sample$class))
    ## cross-validate to find the number of rounds
    cv_result <- xgboost::xgb.cv(data = train_xg,
                                 nrounds = 50,
                                 early_stopping_rounds = 3,
                                 nfold = 5,
                                 objective = "binary:logistic",
                                 verbose = 0)
    ## refit on the whole sample with the selected number of rounds
    booster <- xgboost::xgboost(data = train_xg,
                                nrounds = cv_result$best_iteration,
                                objective = "binary:logistic",
                                verbose = 0)
    grid_pred <- predict(booster,
                         newdata = as.matrix(density[, features]))
    cbind(density, "fitted" = grid_pred)
}

## Animate the xgboost boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_xgb <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_xgboost)
anim_save("/home/jovyan/work/bayeserror/xgboost_mvn.gif", animation = anim_mvn_xgb)
anim_mix_xgb <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_xgboost)
anim_save("/home/jovyan/work/bayeserror/xgboost_mix.gif", animation = anim_mix_xgb)

Neural Networks

Feedforward Perceptrons

## Feed-forward network with 4 hidden units (`nnet`) on each data set.
## The seed is fixed before each fit. Raw network outputs are shifted
## by 0.5 so that the zero level of `fitted` is an output of 0.5.
## These assignments overwrite the `pred_nn` / `density_nn` variables
## used in the nearest-neighbour section.

## `sample_mvn` / `density_mvn`
set.seed(31415)
fit_nn <- nnet::nnet(factor(class) ~ x + y,
                     data = sample_mvn,
                     size = 4,
                     decay = 0.01,
                     rang = 0.3,
                     maxit = 200)
pred_nn <- predict(fit_nn, newdata = density_mvn, type = "raw")
density_nn <- cbind(density_mvn, "fitted" = pred_nn - 0.5)
gg_plot_boundary(density_nn, sample_mvn, title = "Neural Network")

## `sample_mix` / `density_mix`
set.seed(31415)
fit_nn <- nnet::nnet(factor(class) ~ x + y,
                     data = sample_mix,
                     size = 4,
                     decay = 0.01,
                     rang = 0.3,
                     maxit = 200)
pred_nn <- predict(fit_nn, newdata = density_mix, type = "raw")
density_nn <- cbind(density_mix, "fitted" = pred_nn - 0.5)
gg_plot_boundary(density_nn, sample_mix, title = "Neural Network")

##' Fit a single-hidden-layer network and predict on a feature grid
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on
##' @param seed the random seed set before fitting
##' @return `density` with the raw network output attached under the
##'     name `fitted` (not shifted by 0.5)
fit_and_predict_nn <- function(sample, density, seed = 31415) {
    ## fix the seed before the fit
    set.seed(seed)
    network <- nnet::nnet(factor(class) ~ x + y,
                          data = sample,
                          size = 4,
                          decay = 0.01,
                          rang = 0.3,
                          maxit = 200,
                          trace = FALSE)
    raw_output <- predict(network, newdata = density, type = "raw")
    cbind(density, "fitted" = raw_output)
}

## Animate the neural-network boundary on each data set, adding 10
## points per step. Each animation is handed to `anim_save` explicitly
## instead of relying on its "last rendered animation" default.
anim_mvn_nn <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_nn)
anim_save("/home/jovyan/work/bayeserror/nn_mvn.gif", animation = anim_mvn_nn)
anim_mix_n <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_nn)
anim_save("/home/jovyan/work/bayeserror/nn_mix.gif", animation = anim_mix_n)

Extreme Learning Machines

## Extreme learning machine with 10 hidden units and a sigmoid
## activation, trained on `sample_mvn` with a one-hot encoded class.
## The first column of the prediction is shifted by 0.5 and used as the
## fitted value.
## NOTE(review): column 1 is presumably the output for the first
## encoded class, whereas the other learners use class "1" -- confirm
## which class this column refers to.
set.seed(31415)
fit_elm <- elmNNRcpp::elm_train(x = as.matrix(sample_mvn[, c("x", "y")]),
                                y = elmNNRcpp::onehot_encode(sample_mvn[["class"]]),
                                nhid = 10,
                                actfun = "sig")
pred_elm <- elmNNRcpp::elm_predict(fit_elm,
                                   as.matrix(density_mvn[, c("x", "y")]))
density_elm <- cbind(density_mvn, "fitted" = pred_elm[, 1] - 0.5)
gg_plot_boundary(density_elm, sample_mvn, title = "ELM")


##' Fit an extreme learning machine and predict on a feature grid
##'
##' Uses 10 hidden units with a sigmoid activation and a one-hot
##' encoded class; the seed is fixed on every call.
##'
##' @param sample `data.frame`: training data with columns `x`, `y`,
##'     and `class`
##' @param density `data.frame`: the grid to predict on; should have
##'     columns `x` and `y`
##' @return `density` with an added `fitted` column taken from the
##'     first column of the prediction (not shifted by 0.5)
fit_and_predict_elm <- function(sample, density) {
    set.seed(31415)
    features <- c("x", "y")
    elm_model <- elmNNRcpp::elm_train(
                                x = as.matrix(sample[, features]),
                                y = elmNNRcpp::onehot_encode(sample[["class"]]),
                                nhid = 10,
                                actfun = "sig")
    grid_pred <- elmNNRcpp::elm_predict(elm_model,
                                        as.matrix(density[, features]))
    cbind(density, "fitted" = grid_pred[, 1])
}

## Animate the ELM boundary on each data set, adding 10 points per
## step. Each animation is handed to `anim_save` explicitly instead of
## relying on its "last rendered animation" default.
anim_mvn_elm <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_elm)
anim_save("/home/jovyan/work/bayeserror/elm_mvn.gif", animation = anim_mvn_elm)

anim_mix_elm <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_elm)
anim_save("/home/jovyan/work/bayeserror/elm_mix.gif", animation = anim_mix_elm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment