Looking at the decision boundary a classifier generates can give us some geometric intuition about the decision rule a classifier uses and how this decision rule changes as the classifier is trained on more data.
## Plot a fitted decision boundary against the optimal one
##
## @param density data.frame over the feature grid with columns `x`,
##   `y`, `optimal`, and `fitted`; each boundary is drawn as the zero
##   contour of its column (optimal dashed, fitted solid)
## @param points data.frame of sample points, drawn by `gg_sample()`
## @param title plot title; the default `""` draws no title
## @return a ggplot object
gg_plot_boundary <- function(density, points, title = "") {
  ## NULL (not "") so an untitled plot reserves no space for a title
  plot_title <- if (length(title) > 0 && nzchar(title[[1]])) title else NULL
  ggplot() +
    ## gg_sample(data = density, size = 1.5, alpha = 0.1, shape = 15) +
    gg_sample(data = points) +
    gg_density(data = density, z = optimal, breaks = c(0), linetype = 2) +
    gg_density(data = density, z = fitted, breaks = c(0)) +
    coord_fixed(expand = FALSE) +
    xlim(min(density$x), max(density$x)) +
    ylim(min(density$y), max(density$y)) +
    labs(title = plot_title)
}
##' Animate the evolution of a decision boundary as the sample size grows
##'
##' The sample is shuffled once and then revealed `delta` points at a
##' time. At each step the learner is refit on every point revealed so
##' far, and its boundary is drawn beside the optimal one.
##'
##' @param sample `data.frame`: the complete sample data; should have
##'   columns `x`, `y`, and `class`
##' @param density `data.frame`: the feature grid; should have columns
##'   `x`, `y`, and `optimal`, whose zero contour is drawn as the optimal
##'   decision boundary
##' @param delta `integer`: how many points to add at each step of the
##'   animation
##' @param fit_and_predict `function(sample, density)`: fits a learner
##'   to the sample data and returns `density` with an added `fitted`
##'   column. The fitted boundary is drawn at the 0.5 contour of
##'   `fitted`, so it should be a class-1 probability or 0/1 indicator.
##' @return the animation rendered by `animate()` with the gifski renderer
animate_boundary <- function(sample, density, delta, fit_and_predict) {
  ## A data.frame of successive cumulative prefixes of the shuffled
  ## `data`, each `delta` rows longer than the last and labelled by
  ## `group`. The last group holds every row, even when `delta` does not
  ## divide `nrow(data)`.
  sequence_data <- function(data, delta) {
    rows <- nrow(data)
    ## round up so a trailing partial step is kept rather than dropped
    n <- ceiling(rows / delta)
    ## `base::` makes explicit that this is the RNG function, not the
    ## `sample` argument of the enclosing function
    data <- data[base::sample(rows), ]
    go <- function(i) {
      h <- min(i * delta, rows)
      bind_cols(
        head(data, h),
        group = rep.int(i, h))
    }
    bind_rows(lapply(seq_len(n), go))
  }
  ## Sequence the sample data
  sample_sequenced <- sequence_data(sample, delta)
  ## Refit the learner on each cumulative group and predict on the grid
  density_sequenced <- sample_sequenced %>%
    group_by(group) %>%
    group_modify(~ fit_and_predict(.x, density)) %>%
    ungroup()
  ## Define the animation
  anim <- ggplot() +
    ## Plot the sample
    geom_point(data = sample_sequenced,
               aes(x = x, y = y,
                   color = factor(class),
                   shape = factor(class)),
               size = 3,
               alpha = 0.5) +
    ## Plot the optimal decision boundary (zero contour of `optimal`)
    geom_contour(data = density_sequenced,
                 aes(x, y, z = optimal),
                 breaks = c(0),
                 color = "black",
                 size = 1,
                 linetype = 2) +
    ## Plot the fitted decision boundary (0.5 contour of `fitted`)
    geom_contour(data = density_sequenced,
                 aes(x, y, z = fitted),
                 breaks = c(0.5),
                 color = "black",
                 size = 1) +
    coord_fixed(expand = FALSE) +
    xlim(min(density$x), max(density$x)) +
    ylim(min(density$y), max(density$y)) +
    theme_linedraw() +
    theme(plot.title = element_text(hjust = 0.5, size = 20),
          legend.position = "none",
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank()) +
    ## Animate the sample and the fitted boundary
    transition_manual(group)
  animate(anim, renderer = gifski_renderer(),
          width = 800, height = 800)
}
## Lattice version of the decision-boundary plot. It shades the feature
## grid by fitted class, draws the optimal (dashed) and fitted (solid)
## boundaries as zero contours, and overlays the sample points.
##
## @param density data.frame over the feature grid with columns `x`, `y`,
##   `optimal`, and `fitted`; the boundaries are their zero contours
## @param sample data.frame with columns `x`, `y`, and `class`
## @param title plot title
## @return a trellis object
##
## Combining trellis objects with `+` needs latticeExtra's `+.trellis`
## (NOTE(review): presumably attached elsewhere). That overlay keeps only
## the first object's plot-level settings, so `main` must go on the first
## layer; on a later layer it is silently ignored.
lattice_plot_boundary <- function(density, sample, title) {
  ## Fitted class (0/1) at each grid point. `[[` also works for tibbles,
  ## where `density[, "fitted"]` would return a one-column tibble.
  fitted_class <- as.integer(density[["fitted"]] > 0)
  ## classes in the feature grid (carries the plot title)
  lattice::xyplot(y ~ x, groups = fitted_class,
                  data = density,
                  cex = 1, pch = 20, alpha = 0.1,
                  aspect = 1,
                  main = title) +
    ## the optimal decision boundary
    lattice::contourplot(optimal ~ x + y,
                         data = density,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3,
                         lty = 2,
                         aspect = 1) +
    ## the fitted decision boundary
    lattice::contourplot(fitted ~ x + y,
                         data = density,
                         at = c(0),
                         labels = FALSE,
                         lwd = 3) +
    ## the sample
    lattice::xyplot(y ~ x, groups = class,
                    data = sample,
                    pch = 19, alpha = 0.5)
}
## Exploratory views of the MVN example: the two class densities `p_0`
## and `p_1`, and `optimal`, whose zero contour is drawn as the optimal
## boundary.
## Overlaid contours of the two class densities
lattice::levelplot(p_0 ~ x + y, data = density_mvn, aspect = 1,
                   contour = TRUE, region = FALSE, cuts = 10) +
  lattice::levelplot(p_1 ~ x + y, data = density_mvn,
                     contour = TRUE, region = FALSE, cuts = 10)
## Level plot of `optimal` with its zero contour drawn thick
lattice::levelplot(optimal ~ x + y, data = density_mvn,
                   contour = TRUE, cuts = 20, aspect = 1) +
  lattice::contourplot(optimal ~ x + y, data = density_mvn,
                       at = c(0), labels = FALSE, lwd = 3)
## Coarser class-density contours together with the zero contour of `optimal`
lattice::levelplot(p_0 ~ x + y, data = density_mvn, aspect = 1,
                   contour = TRUE, region = FALSE, cuts = 5) +
  lattice::levelplot(p_1 ~ x + y, data = density_mvn,
                     contour = TRUE, region = FALSE, cuts = 5) +
  lattice::contourplot(optimal ~ x + y, data = density_mvn,
                       at = c(0), labels = FALSE, lwd = 3)
## Zero contour of `optimal` (the optimal boundary) with the MVN sample
## overlaid and grouped by class. The sample is `sample_mvn`; a bare
## `points` here would resolve to the graphics::points() function.
lattice::contourplot(optimal ~ x + y,
                     data = density_mvn,
                     at = c(0),
                     labels = FALSE,
                     lwd = 3,
                     aspect = 1) +
  lattice::xyplot(y ~ x, groups = class, data = sample_mvn)
## Linear probability model: least squares of the class label on x and y.
## Subtracting 0.5 moves the fitted 0.5 level onto the zero contour that
## gg_plot_boundary() draws.
fit_lm <- lm(formula = class ~ x + y, data = sample_mvn)
pred_lm <- predict(fit_lm, newdata = density_mvn)
density_lm <- cbind(density_mvn, fitted = pred_lm - 0.5)
gg_plot_boundary(density_lm, sample_mvn, title = "Linear")
## NOTE(review): the disabled confusion-matrix sketch below thresholds
## the uncentred `pred_lm` at 0 (the centred boundary is at 0.5). It also
## reads a `class_lm` column that is never added to `density`.
## class_lm <- ifelse(pred_lm > 0, 1, 0)
## confusion_lm <- table(density[, "class"],
## density[, "class_lm"],
## dnn = c("True", "Predicted"))
## fourfoldplot(confusion_lm, main = "Linear")
## Logistic regression on each sample. The predicted response is shifted
## by 0.5 so the fitted boundary lies on the zero contour.
fit_mvn_glm <- glm(class ~ x + y, family = binomial, data = sample_mvn)
pred_mvn_glm <- predict(fit_mvn_glm, type = "response", newdata = density_mvn)
density_mvn_glm <- cbind(density_mvn, fitted = pred_mvn_glm - 0.5)
gg_plot_boundary(density_mvn_glm, sample_mvn, title = "Logistic")

fit_mix_glm <- glm(class ~ x + y, family = binomial, data = sample_mix)
pred_mix_glm <- predict(fit_mix_glm, type = "response", newdata = density_mix)
density_mix_glm <- cbind(density_mix, fitted = pred_mix_glm - 0.5)
gg_plot_boundary(density_mix_glm, sample_mix, title = "Logistic")
## Fit a logistic regression of `class` on `x` and `y` to `sample`.
##
## @param sample data.frame with columns `x`, `y`, and `class`
## @param density data.frame grid with columns `x` and `y`
## @return `density` with an added `fitted` column holding the predicted
##   response (class-1 probability) at each grid row
fit_and_predict_glm <- function(sample, density) {
  fit_glm <- glm(class ~ x + y, data = sample, family = binomial)
  ## predict on the grid passed in, not a global one, so each dataset is
  ## evaluated on its own grid
  pred_glm <- predict(fit_glm, newdata = density, type = "response")
  cbind(density, fitted = pred_glm)
}
## Logistic-regression animations, adding 10 sample points per frame
anim_mvn_glm <- animate_boundary(sample = sample_mvn, density = density_mvn,
                                 delta = 10,
                                 fit_and_predict = fit_and_predict_glm)
anim_save(filename = "/home/jovyan/work/bayeserror/glm_mvn.gif",
          animation = anim_mvn_glm)
anim_mix_glm <- animate_boundary(sample = sample_mix, density = density_mix,
                                 delta = 10,
                                 fit_and_predict = fit_and_predict_glm)
anim_save(filename = "/home/jovyan/work/bayeserror/glm_mix.gif",
          animation = anim_mix_glm)
## Logistic GAM with a bivariate smooth s(x, y). gam() has no `class`
## argument (an unknown `class = "bernoulli"` falls through `...` and
## leaves the default gaussian family), so the binomial family must be
## given via `family`. The predicted probability is shifted by 0.5 so the
## boundary is the zero contour.
fit_mvn_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample_mvn)
pred_mvn_gam <- predict(fit_mvn_gam, newdata = density_mvn, type = "response")
density_mvn_gam <- cbind(density_mvn, "fitted" = as.numeric(pred_mvn_gam) - 0.5)
gg_plot_boundary(density_mvn_gam, sample_mvn, title = "GAM")
fit_mix_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample_mix)
pred_mix_gam <- predict(fit_mix_gam, newdata = density_mix, type = "response")
density_mix_gam <- cbind(density_mix, "fitted" = as.numeric(pred_mix_gam) - 0.5)
gg_plot_boundary(density_mix_gam, sample_mix, title = "GAM")
## Fit a logistic GAM, class ~ s(x, y), to `sample`.
##
## @param sample data.frame with columns `x`, `y`, and `class`
## @param density data.frame grid with columns `x` and `y`
## @return `density` with an added `fitted` column holding the predicted
##   class-1 probability at each grid row
fit_and_predict_gam <- function(sample, density) {
  ## `family` (not `class`) selects the binomial model in mgcv::gam()
  fit_gam <- mgcv::gam(class ~ s(x, y), family = binomial, data = sample)
  pred_gam <- predict(fit_gam, newdata = density, type = "response")
  cbind(density, "fitted" = as.numeric(pred_gam))
}
## GAM animations, adding 100 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_gam <- animate_boundary(sample_mvn, density_mvn, 100, fit_and_predict_gam)
anim_save("/home/jovyan/work/bayeserror/gam_mvn.gif", animation = anim_mvn_gam)
anim_mix_gam <- animate_boundary(sample_mix, density_mix, 100, fit_and_predict_gam)
anim_save("/home/jovyan/work/bayeserror/gam_mix.gif", animation = anim_mix_gam)
## MARS (earth) with a binomial GLM on the selected basis functions. The
## type = "response" prediction is shifted by 0.5 so the boundary is the
## zero contour.
fit_mvn_mars <- earth::earth(factor(class) ~ x + y,
                             glm = list(family = "binomial"),
                             data = sample_mvn)
pred_mvn_mars <- predict(fit_mvn_mars, type = "response", newdata = density_mvn)
density_mvn_mars <- cbind(density_mvn,
                          "fitted" = as.numeric(pred_mvn_mars) - 0.5)
gg_plot_boundary(density_mvn_mars, sample_mvn, title = "MARS")

fit_mars <- earth::earth(factor(class) ~ x + y,
                         glm = list(family = "binomial"),
                         data = sample_mix)
pred_mars <- predict(fit_mars, type = "response", newdata = density_mix)
density_mars <- cbind(density_mix, "fitted" = as.numeric(pred_mars) - 0.5)
gg_plot_boundary(density_mars, sample_mix, title = "MARS")
## Fit MARS (earth, binomial GLM) to `sample` and return `density` with a
## `fitted` column of type = "response" predictions on the grid.
fit_and_predict_mars <- function(sample, density) {
  model <- earth::earth(factor(class) ~ x + y,
                        glm = list(family = "binomial"),
                        data = sample)
  response <- as.numeric(predict(model, type = "response", newdata = density))
  return(cbind(density, "fitted" = response))
}
## MARS animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_mars <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_mars)
anim_save("/home/jovyan/work/bayeserror/mars_mvn.gif", animation = anim_mvn_mars)
anim_mix_mars <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_mars)
anim_save("/home/jovyan/work/bayeserror/mars_mix.gif", animation = anim_mix_mars)
## PolyMARS classifier. Column 2 of the prediction matrix is used as the
## class-1 score, shifted by 0.5 so the boundary is the zero contour.
fit_pmars <- polspline::polymars(sample_mvn[["class"]],
                                 sample_mvn[, c("x", "y")],
                                 classify = TRUE)
pred_pmars <- predict(fit_pmars, x = as.data.frame(density_mvn[, c("x", "y")]))
density_pmars <- cbind(density_mvn, "fitted" = pred_pmars[, 2] - 0.5)
gg_plot_boundary(density_pmars, sample_mvn, title = "PolyMARS")

fit_pmars <- polspline::polymars(sample_mix[["class"]],
                                 sample_mix[, c("x", "y")],
                                 classify = TRUE)
pred_pmars <- predict(fit_pmars, x = as.data.frame(density_mix[, c("x", "y")]))
density_pmars <- cbind(density_mix, "fitted" = pred_pmars[, 2] - 0.5)
gg_plot_boundary(density_pmars, sample_mix, title = "PolyMARS")
## Fit a PolyMARS classifier to `sample` and return `density` with a
## `fitted` column taken from column 2 of the predicted matrix.
fit_and_predict_polymars <- function(sample, density) {
  train_xy <- as.data.frame(sample[, c("x", "y")])
  model <- polspline::polymars(sample[["class"]], train_xy, classify = TRUE)
  grid_xy <- as.data.frame(density[, c("x", "y")])
  scores <- predict(model, x = grid_xy)
  return(cbind(density, "fitted" = scores[, 2]))
}
## PolyMARS animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_pmars <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_polymars)
anim_save("/home/jovyan/work/bayeserror/pmars_mvn.gif", animation = anim_mvn_pmars)
anim_mix_pmars <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_polymars)
anim_save("/home/jovyan/work/bayeserror/pmars_mix.gif", animation = anim_mix_pmars)
## LDA and QDA fitted to the density grid `density_mvn` itself rather than
## to the sample. NOTE(review): presumably this shows each model's
## best-case boundary. `density_lda` is not plotted in this section.
## Posterior class-1 probabilities are shifted by 0.5 so the boundary is
## the zero contour.
fit_lda <- MASS::lda(class ~ x + y, data = density_mvn)
pred_lda <- predict(fit_lda, newdata = density_mvn)
density_lda <- cbind(density_mvn, fitted = pred_lda$posterior[, "1"] - 0.5)

fit_qda <- MASS::qda(class ~ x + y, data = density_mvn)
pred_qda <- predict(fit_qda, newdata = density_mvn)
density_qda <- cbind(density_mvn, fitted = pred_qda$posterior[, "1"] - 0.5)
gg_plot_boundary(density_qda, sample_mvn, title = "QDA")

## QDA fitted to the MVN sample
fit_qda_points <- MASS::qda(class ~ x + y, data = sample_mvn)
pred_qda_points <- predict(fit_qda_points, newdata = density_mvn)
density_qda_points <- cbind(density_mvn,
                            fitted = pred_qda_points$posterior[, "1"] - 0.5)
gg_plot_boundary(density_qda_points, sample_mvn, title = "QDA")
## Fit QDA to `sample` and return `density` with a `fitted` column: the
## posterior probability of class "1" at each grid row.
fit_and_predict_qda <- function(sample, density) {
  posterior <- predict(MASS::qda(class ~ x + y, data = sample),
                       newdata = density)$posterior
  return(cbind(density, fitted = posterior[, "1"]))
}
## QDA animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_qda <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_qda)
anim_save("/home/jovyan/work/bayeserror/qda_mvn.gif", animation = anim_mvn_qda)
anim_mix_qda <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_qda)
anim_save("/home/jovyan/work/bayeserror/qda_mix.gif", animation = anim_mix_qda)
## k-nearest neighbours (k = 5). knn() returns the predicted class as a
## factor, so as.integer() gives level index 1 or 2. Subtracting 1.5 maps
## these to -0.5 / 0.5, so the zero contour drawn by gg_plot_boundary()
## falls strictly between the two classes. `[["class"]]` extracts a
## vector for both data.frames and tibbles.
pred_nn <- class::knn(train = sample_mvn[, c("x", "y")],
                      cl = factor(sample_mvn[["class"]]),
                      test = density_mvn[, c("x", "y")],
                      k = 5)
density_nn <- cbind(density_mvn, "fitted" = as.integer(pred_nn) - 1.5)
gg_plot_boundary(density_nn, sample_mvn, title = "Nearest Neighbors")

pred_nn <- class::knn(train = sample_mix[, c("x", "y")],
                      cl = factor(sample_mix[["class"]]),
                      test = density_mix[, c("x", "y")],
                      k = 5)
density_nn <- cbind(density_mix, "fitted" = as.integer(pred_nn) - 1.5)
## overlay the mixture sample the model was trained on
gg_plot_boundary(density_nn, sample_mix, title = "Nearest Neighbors")
## Classify each grid row by 5-nearest-neighbour vote over `sample`.
##
## @param sample data.frame with columns `x`, `y`, and `class`
## @param density data.frame grid with columns `x` and `y`
## @return `density` with an added `fitted` column: 0 for the first class
##   level and 1 for the second. This matches the probability scale of
##   the other fit_and_predict_* learners, so animate_boundary()'s 0.5
##   contour lies between the classes.
fit_and_predict_knn <- function(sample, density) {
  pred_knn <- class::knn(train = sample[, c("x", "y")],
                         cl = factor(sample$class),
                         test = density[, c("x", "y")],
                         k = 5)
  ## factor level index (1 or 2) -> 0/1 indicator
  cbind(density, fitted = as.integer(pred_knn) - 1)
}
## k-NN animations, adding 10 sample points per frame
anim_mvn_knn <- animate_boundary(sample = sample_mvn, density = density_mvn,
                                 delta = 10,
                                 fit_and_predict = fit_and_predict_knn)
anim_save(filename = "/home/jovyan/work/bayeserror/knn_mvn.gif",
          animation = anim_mvn_knn)
anim_mix_knn <- animate_boundary(sample = sample_mix, density = density_mix,
                                 delta = 10,
                                 fit_and_predict = fit_and_predict_knn)
anim_save(filename = "/home/jovyan/work/bayeserror/knn_mix.gif",
          animation = anim_mix_knn)
## Weighted k-NN with a gaussian kernel (kknn::train.kknn). Column 2 of
## the predicted probabilities is shifted by 0.5 so the boundary is the
## zero contour.
fit_kknn <- kknn::train.kknn(factor(class) ~ x + y,
                             kernel = "gaussian",
                             data = sample_mvn)
pred_kknn <- predict(fit_kknn, type = "prob", newdata = density_mvn)
density_kknn <- cbind(density_mvn, fitted = pred_kknn[, 2] - 0.5)
gg_plot_boundary(density_kknn, sample_mvn, title = "KKNN")
## RBF-kernel SVM with a probability model (kernlab::ksvm). The class "1"
## probability is shifted by 0.5 so the boundary is the zero contour.
fit_svm_points <- kernlab::ksvm(factor(class) ~ x + y,
                                prob.model = TRUE,
                                kernel = "rbfdot",
                                data = sample_mvn)
pred_svm_points <- kernlab::predict(fit_svm_points,
                                    type = "probabilities",
                                    newdata = density_mvn)
density_svm_points <- cbind(density_mvn,
                            fitted = pred_svm_points[, "1"] - 0.5)
gg_plot_boundary(density_svm_points, sample_mvn, title = "SVM")

fit_svm <- kernlab::ksvm(factor(class) ~ x + y,
                         prob.model = TRUE,
                         kernel = "rbfdot",
                         data = sample_mix)
pred_svm <- kernlab::predict(fit_svm,
                             type = "probabilities",
                             newdata = density_mix)
density_svm <- cbind(density_mix, fitted = pred_svm[, "1"] - 0.5)
gg_plot_boundary(density_svm, sample_mix, title = "SVM")
## Fit an RBF-kernel SVM with a probability model to `sample` and return
## `density` with a `fitted` column: the predicted class "1" probability.
fit_and_predict_svm <- function(sample, density) {
  svm <- kernlab::ksvm(factor(class) ~ x + y, data = sample,
                       kernel = "rbfdot", prob.model = TRUE)
  probs <- kernlab::predict(svm, newdata = density, type = "probabilities")
  return(cbind(density, fitted = probs[, "1"]))
}
## SVM animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_svm <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_svm)
anim_save("/home/jovyan/work/bayeserror/svm_mvn.gif", animation = anim_mvn_svm)
anim_mix_svm <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_svm)
anim_save("/home/jovyan/work/bayeserror/svm_mix.gif", animation = anim_mix_svm)
## Classification tree (rpart, method = "class"). The predicted class "1"
## probability is shifted by 0.5 so the boundary is the zero contour.
fit_rpart_points <- rpart::rpart(class ~ x + y, method = "class",
                                 data = sample_mvn)
pred_rpart_points <- predict(fit_rpart_points, newdata = density_mvn)
density_rpart_points <- cbind(density_mvn,
                              fitted = pred_rpart_points[, "1"] - 0.5)
gg_plot_boundary(density_rpart_points, sample_mvn, title = "Decision Tree")

fit_rpart <- rpart::rpart(class ~ x + y, method = "class", data = sample_mix)
pred_rpart <- predict(fit_rpart, newdata = density_mix)
density_rpart <- cbind(density_mix, fitted = pred_rpart[, "1"] - 0.5)
gg_plot_boundary(density_rpart, sample_mix, title = "Decision Tree")
## Fit a classification tree to `sample` and return `density` with a
## `fitted` column: the predicted class "1" probability at each grid row.
fit_and_predict_rpart <- function(sample, density) {
  tree <- rpart::rpart(class ~ x + y, method = "class", data = sample)
  class_probs <- predict(tree, newdata = density)
  return(cbind(density, fitted = class_probs[, "1"]))
}
## Decision-tree animations, adding 10 sample points per frame. Each
## rendered animation is passed to anim_save() explicitly rather than
## relying on gganimate's implicit "last animation".
anim_mvn_rpart <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_rpart)
anim_save("/home/jovyan/work/bayeserror/rpart_mvn.gif", animation = anim_mvn_rpart)
anim_mix_rpart <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_rpart)
anim_save("/home/jovyan/work/bayeserror/rpart_mix.gif", animation = anim_mix_rpart)
## Probability random forest (ranger). The class "1" probability is
## shifted by 0.5 so the boundary is the zero contour.
fit_rf <- ranger::ranger(factor(class) ~ x + y, probability = TRUE,
                         data = sample_mvn)
pred_rf <- predict(fit_rf, data = density_mvn)
density_rf <- cbind(density_mvn, fitted = pred_rf$predictions[, "1"] - 0.5)
gg_plot_boundary(density_rf, sample_mvn, title = "Random Forest")

fit_rf <- ranger::ranger(factor(class) ~ x + y, probability = TRUE,
                         data = sample_mix)
pred_rf <- predict(fit_rf, data = density_mix)
density_rf <- cbind(density_mix, fitted = pred_rf$predictions[, "1"] - 0.5)
gg_plot_boundary(density_rf, sample_mix, title = "Random Forest")
## Fit a probability random forest to `sample` and return `density` with
## a `fitted` column: the predicted class "1" probability.
fit_and_predict_rf <- function(sample, density) {
  forest <- ranger::ranger(factor(class) ~ x + y, probability = TRUE,
                           data = sample)
  class_probs <- predict(forest, data = density)$predictions
  cbind(density, fitted = class_probs[, "1"])
}
## Random-forest animations, adding 10 sample points per frame. The MVN
## file is named rf_mvn.gif to match rf_mix.gif and the other *_mvn.gif
## outputs.
anim_mvn_rf <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_rf)
anim_save("/home/jovyan/work/bayeserror/rf_mvn.gif", animation = anim_mvn_rf)
anim_mix_rf <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_rf)
anim_save("/home/jovyan/work/bayeserror/rf_mix.gif", animation = anim_mix_rf)
## Gradient-boosted trees with bernoulli loss: 100 trees for the MVN
## sample and 500 for the mixture, predicting with all fitted trees. The
## response is shifted by 0.5 so the boundary is the zero contour.
fit_gbm <- gbm::gbm(class ~ x + y, distribution = "bernoulli",
                    n.trees = 100, data = sample_mvn)
pred_gbm <- predict(fit_gbm, newdata = density_mvn,
                    type = "response", n.trees = 100)
density_gbm <- cbind(density_mvn, "fitted" = pred_gbm - 0.5)
gg_plot_boundary(density_gbm, sample_mvn, title = "Boosted Trees")

fit_gbm <- gbm::gbm(class ~ x + y, distribution = "bernoulli",
                    n.trees = 500, data = sample_mix)
pred_gbm <- predict(fit_gbm, newdata = density_mix,
                    type = "response", n.trees = 500)
density_gbm <- cbind(density_mix, "fitted" = pred_gbm - 0.5)
gg_plot_boundary(density_gbm, sample_mix, title = "Boosted Trees")
## xgboost with a logistic objective. A seeded 5-fold CV with early
## stopping (3 rounds, at most 50) chooses the number of boosting rounds
## for the final fit. The predicted probability is shifted by 0.5 so the
## boundary is the zero contour.
set.seed(31415)
sample_xg <- xgboost::xgb.DMatrix(
  data = as.matrix(sample_mvn[, c("x", "y")]),
  label = as.numeric(sample_mvn$class))
xgcv <- xgboost::xgb.cv(data = sample_xg, objective = "binary:logistic",
                        nrounds = 50, nfold = 5,
                        early_stopping_rounds = 3)
fit_xg <- xgboost::xgboost(data = sample_xg, objective = "binary:logistic",
                           nrounds = xgcv$best_iteration)
pred_xg <- predict(fit_xg, newdata = as.matrix(density_mvn[, c("x", "y")]))
density_xg <- cbind(density_mvn, "fitted" = pred_xg - 0.5)
gg_plot_boundary(density_xg, sample_mvn, title = "xgboost")

set.seed(31415)
sample_xg <- xgboost::xgb.DMatrix(
  data = as.matrix(sample_mix[, c("x", "y")]),
  label = as.numeric(sample_mix$class))
xgcv <- xgboost::xgb.cv(data = sample_xg, objective = "binary:logistic",
                        nrounds = 50, nfold = 5,
                        early_stopping_rounds = 3)
fit_xg <- xgboost::xgboost(data = sample_xg, objective = "binary:logistic",
                           nrounds = xgcv$best_iteration)
pred_xg <- predict(fit_xg, newdata = as.matrix(density_mix[, c("x", "y")]))
density_xg <- cbind(density_mix, "fitted" = pred_xg - 0.5)
gg_plot_boundary(density_xg, sample_mix, title = "xgboost")
## Fit xgboost (logistic objective) to `sample`, choosing the number of
## rounds by a silent 5-fold CV with early stopping, and return `density`
## with a `fitted` column of predicted probabilities. Calls set.seed(), so
## it resets the global RNG state.
fit_and_predict_xgboost <- function(sample, density) {
  set.seed(31415)
  train <- xgboost::xgb.DMatrix(
    data = as.matrix(sample[, c("x", "y")]),
    label = as.numeric(sample$class))
  cv <- xgboost::xgb.cv(data = train, objective = "binary:logistic",
                        nrounds = 50, nfold = 5,
                        early_stopping_rounds = 3, verbose = 0)
  booster <- xgboost::xgboost(data = train, objective = "binary:logistic",
                              nrounds = cv$best_iteration, verbose = 0)
  grid <- as.matrix(density[, c("x", "y")])
  cbind(density, "fitted" = predict(booster, newdata = grid))
}
## xgboost animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_xgb <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_xgboost)
anim_save("/home/jovyan/work/bayeserror/xgboost_mvn.gif", animation = anim_mvn_xgb)
anim_mix_xgb <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_xgboost)
anim_save("/home/jovyan/work/bayeserror/xgboost_mix.gif", animation = anim_mix_xgb)
## Single-hidden-layer network (nnet: 4 hidden units, weight decay 0.01,
## initial weights in [-0.3, 0.3], up to 200 iterations), seeded for
## reproducibility. The raw output is shifted by 0.5 so the boundary is
## the zero contour.
set.seed(31415)
fit_nn <- nnet::nnet(factor(class) ~ x + y, data = sample_mvn,
                     size = 4, decay = 0.01, rang = 0.3, maxit = 200)
pred_nn <- predict(fit_nn, type = "raw", newdata = density_mvn)
density_nn <- cbind(density_mvn, "fitted" = pred_nn - 0.5)
gg_plot_boundary(density_nn, sample_mvn, title = "Neural Network")

set.seed(31415)
fit_nn <- nnet::nnet(factor(class) ~ x + y, data = sample_mix,
                     size = 4, decay = 0.01, rang = 0.3, maxit = 200)
pred_nn <- predict(fit_nn, type = "raw", newdata = density_mix)
density_nn <- cbind(density_mix, "fitted" = pred_nn - 0.5)
gg_plot_boundary(density_nn, sample_mix, title = "Neural Network")
## Fit a 4-hidden-unit nnet to `sample` (seeded with `seed`, which resets
## the global RNG state) and return `density` with a `fitted` column of
## raw network outputs.
fit_and_predict_nn <- function(sample, density, seed = 31415) {
  set.seed(seed)
  net <- nnet::nnet(factor(class) ~ x + y, data = sample,
                    size = 4, decay = 0.01, rang = 0.3,
                    maxit = 200, trace = FALSE)
  return(cbind(density,
               "fitted" = predict(net, type = "raw", newdata = density)))
}
## Neural-network animations, adding 10 sample points per frame. Each
## rendered animation is passed to anim_save() explicitly rather than
## relying on gganimate's implicit "last animation".
anim_mvn_nn <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_nn)
anim_save("/home/jovyan/work/bayeserror/nn_mvn.gif", animation = anim_mvn_nn)
anim_mix_n <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_nn)
anim_save("/home/jovyan/work/bayeserror/nn_mix.gif", animation = anim_mix_n)
## Extreme learning machine (elmNNRcpp) with 10 sigmoid hidden units,
## trained on one-hot encoded classes and seeded for reproducibility.
## Output column 1 is shifted by 0.5 so the boundary is the zero contour.
## NOTE(review): if onehot_encode() orders columns by class value, column
## 1 is the class-0 score, unlike the class-1 columns used elsewhere;
## confirm.
set.seed(31415)
fit_elm <- elmNNRcpp::elm_train(
  x = as.matrix(sample_mvn[, c("x", "y")]),
  y = elmNNRcpp::onehot_encode(sample_mvn[["class"]]),
  nhid = 10,
  actfun = "sig")
pred_elm <- elmNNRcpp::elm_predict(fit_elm,
                                   as.matrix(density_mvn[, c("x", "y")]))
density_elm <- cbind(density_mvn, "fitted" = pred_elm[, 1] - 0.5)
gg_plot_boundary(density_elm, sample_mvn, title = "ELM")
## Fit a seeded 10-unit sigmoid ELM to `sample` and return `density` with
## a `fitted` column taken from output column 1 of elm_predict(). Calls
## set.seed(), so it resets the global RNG state.
fit_and_predict_elm <- function(sample, density) {
  set.seed(31415)
  train_x <- as.matrix(sample[, c("x", "y")])
  train_y <- elmNNRcpp::onehot_encode(sample[["class"]])
  model <- elmNNRcpp::elm_train(x = train_x, y = train_y,
                                nhid = 10, actfun = "sig")
  outputs <- elmNNRcpp::elm_predict(model, as.matrix(density[, c("x", "y")]))
  cbind(density, "fitted" = outputs[, 1])
}
## ELM animations, adding 10 sample points per frame. Each rendered
## animation is passed to anim_save() explicitly rather than relying on
## gganimate's implicit "last animation".
anim_mvn_elm <- animate_boundary(sample_mvn, density_mvn, 10, fit_and_predict_elm)
anim_save("/home/jovyan/work/bayeserror/elm_mvn.gif", animation = anim_mvn_elm)
anim_mix_elm <- animate_boundary(sample_mix, density_mix, 10, fit_and_predict_elm)
anim_save("/home/jovyan/work/bayeserror/elm_mix.gif", animation = anim_mix_elm)