Skip to content

Instantly share code, notes, and snippets.

@hoxo-m
Created March 25, 2019 09:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hoxo-m/8cf68fd4de51aeacabbef6f42564f033 to your computer and use it in GitHub Desktop.
AdaBoost のアニメーション
# Label points by whether they lie outside a chi-square quantile radius.
#
# For standard-normal rows, sum(x^2) is chi-square distributed with ncol(X)
# degrees of freedom, so the default prob = 0.5 (the median) splits such data
# into two roughly equal classes.
#
# X:    numeric matrix (or anything as.matrix() accepts), one point per row.
# prob: chi-square quantile used as the squared-radius threshold
#       (default 0.5, which reproduces the original behaviour).
# Returns a numeric vector of length nrow(X): 1 where the squared Euclidean
#   norm of the row exceeds the threshold, -1 otherwise.
f <- function(X, prob = 0.5) {
  X <- as.matrix(X)
  limit <- qchisq(prob, df = ncol(X))
  # Vectorised squared norm of every row instead of an R-level apply() loop.
  ifelse(rowSums(X^2) > limit, 1, -1)
}
# Simulation settings: D input dimensions, M boosting rounds, N samples.
D <- 2
M <- 200
N <- 2000

# Draw N standard-normal points in D dimensions (seeded for reproducibility)
# and label them with f(): +1 when the squared norm exceeds the chi-square
# median, -1 otherwise.
set.seed(314)
X <- matrix(rnorm(N * D), ncol = D)
y <- f(X)

# Quick look at the two classes; y + 2 maps the labels -1/+1 to palette
# colours 1/3.
plot(X[, 1], X[, 2], col = y + 2, pch = 19)
library(rpart)
# Discrete AdaBoost with depth-1 rpart trees (stumps) as weak learners.
#
# X: numeric matrix of predictors, one observation per row.
# y: class labels, one per row of X (the rest of this script uses -1/+1,
#    which predict_DA() relies on).
# M: number of boosting rounds.
# Returns a list with
#   alpha_m: numeric vector of length M, the vote weight of each stump;
#   Gm:      list of the M fitted rpart stumps (fitted on X, so predict()
#            without newdata returns their training-set predictions).
discrete_adaboost <- function(X, y, M) {
  N <- nrow(X)
  w <- rep(1 / N, N)
  alpha_m <- double(M)
  Gm <- vector("list", length = M)
  # The response never changes between rounds, so convert it once.
  y_fac <- factor(y)
  # Keep the weighted error strictly inside (0, 1): a perfect stump gives
  # err_m = 0, hence alpha = Inf, and exp(Inf * 0) = NaN would poison every
  # weight (the next rpart() call then fails).
  eps <- .Machine$double.eps
  for (i in seq_len(M)) {
    Gm[[i]] <- rpart(y_fac ~ X, weights = w, maxdepth = 1)
    y_hat <- predict(Gm[[i]], type = "class")
    err <- y != y_hat
    err_m <- sum(err * w) / sum(w)
    err_m <- min(max(err_m, eps), 1 - eps)
    alpha_m[i] <- log((1 - err_m) / err_m)
    # Up-weight the misclassified observations.
    w <- w * exp(alpha_m[i] * err)
    # This update multiplies sum(w) by 2 * (1 - err_m), so without
    # renormalising it overflows after enough rounds. err_m and alpha depend
    # only on relative weights; rpart's weighted split criterion is presumably
    # scale-invariant as well (NOTE(review): confirm if exact stump
    # reproduction matters).
    w <- w / sum(w)
  }
  list(alpha_m = alpha_m, Gm = Gm)
}
# Fit the ensemble once; the animation replays it round by round.
res <- discrete_adaboost(X = X, y = y, M = M)
# Combine the first n_iter stumps of a discrete_adaboost() fit into class
# predictions for the training data.
#
# DA:     list returned by discrete_adaboost() (fields alpha_m and Gm).
# n_iter: number of boosting rounds to use, at most length(DA$Gm).
# Returns a numeric vector of +1/-1: the sign of the alpha-weighted vote,
#   with ties (a score of exactly 0) resolved to -1. n_iter = 0 yields an
#   empty vector.
predict_DA <- function(DA, n_iter) {
  stopifnot(n_iter <= length(DA$Gm))
  # One stump's vote: its predicted class label converted to a number
  # (assumed "-1"/"1"), scaled by that stump's alpha.
  weighted_vote <- function(i) {
    label <- predict(DA$Gm[[i]], type = "class")
    DA$alpha_m[[i]] * as.integer(as.character(label))
  }
  # Accumulate one score vector instead of materialising an N x n_iter data
  # frame just to sum its rows. Reduce() over an empty list returns NULL,
  # which makes the result below an empty vector.
  score <- Reduce(`+`, lapply(seq_len(n_iter), weighted_vote))
  ifelse(score > 0, 1, -1)
}
library(ggplot2)
# Ground-truth panel: training points coloured by their true label.
# data.frame() names the unnamed matrix columns X1 and X2.
df <- data.frame(X, y = factor(y))
g1 <- ggplot(df, aes(X1, X2)) +
  geom_point(aes(color = y)) +
  xlab(NULL) + ylab(NULL) +
  # `guide = FALSE` is deprecated since ggplot2 3.3.4; "none" hides the legend.
  scale_color_discrete(guide = "none") +
  ggtitle("正解")  # title text means "Ground truth"
library(animation)
library(cowplot)
# Render the animation: left panel = ground truth (g1), right panel = the
# ensemble's prediction after i boosting rounds. Every print() inside
# saveGIF() becomes one frame, so printing a plot k times holds it on screen
# for k * interval seconds.
animation::saveGIF({
  pb <- txtProgressBar(0, M, style = 3)
  # Frames per round: round 1 gets 30, rounds 2-10 get 10, rounds 11-20
  # get 5, and later rounds get 1.
  count <- 30
  for (i in seq_len(M)) {
    setTxtProgressBar(pb, i)
    pred <- predict_DA(res, i)
    df_pred <- data.frame(X, y = factor(pred))
    g2 <- ggplot(df_pred, aes(X1, X2)) +
      geom_point(aes(color = y)) +
      xlab(NULL) + ylab(NULL) +
      # `guide = FALSE` is deprecated since ggplot2 3.3.4.
      scale_color_discrete(guide = "none") +
      ggtitle(sprintf("AdaBoost (iter = %d)", i))
    # Build the side-by-side grid once per round, then repeat the frame.
    panel <- cowplot::plot_grid(g1, g2)
    for (j in seq_len(count)) {
      print(panel)
    }
    count <- if (i < 10) 10 else if (i < 20) 5 else 1
  }
  close(pb)
  # Hold the final round for 50 extra frames.
  for (j in seq_len(50)) {
    print(panel)
  }
}, movie.name = "AdaBoostAnimation2.gif", interval = 0.1,
ani.width = 600, ani.height = 320)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment