Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
XGBoost learns the Canadian Flag -- Demo
library(png)
library(ggplot2)
library(xgboost)
# Load the flag image. For an RGB(A) PNG, readPNG() returns a
# height x width x channel array, so channel index 2 is the green plane.
# NOTE(review): the matrix is named `red` although it holds channel 2;
# the name is kept because the rest of the script refers to it.
img <- readPNG("canada.png")
red <- img[, , 2]
# Image extent in pixels: rows run vertically, columns horizontally.
HEIGHT <- nrow(red)
WIDTH <- ncol(red)
# Probability parameter for the random label flips in get_data_points().
ERROR_RATE <- 0.05
###
# Draw N random pixels from the image and return them as noisy training data.
#
# Reads the globals `red` (label source matrix, indexed [row, column]; built
# from channel 2 of the PNG at top level), WIDTH, HEIGHT and ERROR_RATE.
#
# N: number of points to sample (with replacement).
# Returns a data.frame with numeric columns x (column index), y (row index)
# and p (0/1 label). p is 1 where the rounded `red` value is 0 and 0 where it
# is 1; each label is then flipped independently with probability ERROR_RATE.
get_data_points <- function(N) {
  x <- sample.int(WIDTH, N, replace = TRUE)
  y <- sample.int(HEIGHT, N, replace = TRUE)
  # Matrix indexing is [row, column], i.e. [y, x].
  # The label is the inverted, rounded channel value. Previously the flip mask
  # was drawn with probability 1 - ERROR_RATE, which only produced these labels
  # because the near-universal flip cancelled against the un-inverted value.
  # The inversion is now explicit and ERROR_RATE is the actual noise rate; the
  # resulting label distribution is unchanged.
  p <- 1 - round(red[cbind(y, x)])
  flips <- sample(c(0, 1), N, replace = TRUE,
                  prob = c(1 - ERROR_RATE, ERROR_RATE))
  p[flips == 1] <- 1 - p[flips == 1]
  data.frame(x = as.numeric(x), y = as.numeric(y), p = p)
}
# Sample 7500 training points from the image (labels include random flips).
data <- get_data_points(7500)
###
# Scatter-plot labelled points on a black panel: label 0 in white, 1 in red.
#
# dataplot: data.frame with columns x, y and p, where p holds 0/1 labels.
#           y is negated before plotting so that row 1 is drawn at the top.
# xlim, ylim: axis limits; the defaults are the limits this function always
#           used (ylim applies to the negated y values).
# Returns the ggplot object, which is drawn when printed (e.g. when the call
# is made at top level).
plot_canada <- function(dataplot, xlim = c(0, 240), ylim = c(-120, 0)) {
  dataplot$y <- -dataplot$y
  dataplot$p <- as.factor(dataplot$p)
  ggplot(dataplot, aes(x = x, y = y, color = p)) +
    geom_point(size = 1) +
    scale_x_continuous(limits = xlim) +
    scale_y_continuous(limits = ylim) +
    theme_minimal() +
    theme(
      panel.background = element_rect(fill = "black"),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank()
    ) +
    # Name the colours by label: an unnamed vector is matched by position, so
    # data containing only label 1 would have been drawn in white.
    scale_color_manual(values = c("0" = "white", "1" = "red"))
}
# Show the sampled training points.
plot_canada(data)
###
# Fit a boosted-tree model on the (x, y) coordinates: one boosting round
# with max_depth = 3.
fit <- xgboost(
  data = cbind(data$x, data$y, deparse.level = 0),
  label = data$p,
  nrounds = 1,
  max_depth = 3
)
# Predict every pixel coordinate of the image and threshold at 0.5 to get a
# 0/1 label per pixel, then plot the reconstructed image.
fullimg <- expand.grid(x = as.numeric(1:WIDTH), y = as.numeric(1:HEIGHT))
fullimg$p <- as.numeric(
  predict(fit, newdata = cbind(fullimg$x, fullimg$y, deparse.level = 0)) > 0.5
)
plot_canada(fullimg)
@luckytoilet

This comment has been minimized.

Copy link
Owner Author

@luckytoilet luckytoilet commented Jan 27, 2018

The file canada.png is required to run this script:

canada

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment