# Multi-armed bandit policies simulated and animated in R.
# See: https://pavlov.tech/2019/03/02/animated-multi-armed-bandit-policies/
# Generated using https://github.com/Nth-iteration-labs/contextual and https://yihui.name/animation/
# Note (GitHub): this file contains bidirectional Unicode text that may be interpreted or compiled
# differently than it appears. To review it, open the file in an editor that reveals hidden Unicode
# characters. Learn more about bidirectional Unicode characters.
library(contextual)
library(data.table)
library(animation)

## 1. Bandit Simulation ---------------------------------------------------------------------------

# Number of time steps in the single simulation run.
horizon <- 100

# Epsilon-greedy policy with epsilon = 0.1 on a three-armed Bernoulli bandit
# (weights presumably the per-arm reward probabilities).
policy <- EpsilonGreedyPolicy$new(epsilon = 0.1)
bandit <- BasicBernoulliBandit$new(weights = c(0.4, 0.5, 0.3))
agent <- Agent$new(policy, bandit, "EG")

# Run one seeded simulation that also records the policy's theta values at
# every step (save_theta = TRUE); the animation below replays them.
simulator <- Simulator$new(agents = agent,
                           horizon = horizon,
                           set_seed = 8,
                           save_theta = TRUE,
                           simulations = 1)
hist <- simulator$run()

# Retrieve the saved parameter values as a numeric matrix; the animation
# indexes its rows by time step.
td <- hist$get_theta("EG", to_numeric_matrix = TRUE)
## 2. Bandit Animation ----------------------------------------------------------------------------
# Per-step color matrix: at each step the chosen arm is green when the policy
# exploited (exploit == 1) and red when it explored (exploit == 0); every
# other arm stays gray.
# The same which() result selects both the rows and the chosen arms, so the
# two columns of each cbind() index always line up (the original used two
# different predicates for them).
exploit_rows <- which(td[, "exploit"] == 1)
explore_rows <- which(td[, "exploit"] == 0)
color_matrix <- matrix("gray", nrow = nrow(td), ncol = 3)
color_matrix[cbind(exploit_rows, td[exploit_rows, "choice"])] <- "green"
color_matrix[cbind(explore_rows, td[explore_rows, "choice"])] <- "red"
colnames(color_matrix) <- c("C1", "C2", "C3")
cm <- as.data.frame(color_matrix, stringsAsFactors = FALSE)
message("Starting compilation of animation")
animation::ani.options(interval = 0.3, ani.width = 450, ani.height = 400, verbose = FALSE)
saveHTML({
  # Set again inside the expression; this later call supersedes the 0.3 above.
  animation::ani.options(interval = 0.05)
  # One frame per recorded step (previously hard-coded as 1:100).
  for (i in seq_len(nrow(td))) {
    par(mar = c(5, 4, 4, 9))  # wide right margin leaves room for the legend
    # The small offset keeps arms whose estimate is 0 visible as a thin bar.
    heights <- as.numeric(td[i, c("mean1", "mean2", "mean3")]) + 0.015
    barplot(heights,
            ylim = c(0, 1.015),
            ylab = "Average Reward",
            xlab = "Arm",
            main = paste0("EpsilonGreedy\nt = ",
                          sprintf("%03d", i),
                          " | choice = ",
                          hist$data$choice[[i]],
                          " | reward = ",
                          hist$data$reward[[i]], "\n"),
            names.arg = c("1", "2", "3"),
            col = unlist(cm[i, c("C1", "C2", "C3")], use.names = FALSE))
    box()
    axis(side = 1, at = c(0.7, 1.9, 3.1), labels = FALSE)
    legend("bottomright", bty = "n", xpd = TRUE, inset = c(-0.42, 0),
           title = "Choice",
           legend = c("Exploiting", "Exploring"), fill = c("green", "red"))
    ani.pause()
  }
}, htmlfile = "index.html", img.name = "eg", navigator = FALSE, imgdir = "eg",
autoplay = FALSE, single.opts = "'theme': 'light', 'utf8': false,'controls':
['first', 'previous', 'play', 'next', 'last', 'loop','speed']")
message("Completed animation")
# Best-effort close of any device still open; dev.off() errors when none is.
invisible(tryCatch(dev.off(), error = function(e) NULL))
# Note (GitHub): this file contains bidirectional Unicode text that may be interpreted or compiled
# differently than it appears. To review it, open the file in an editor that reveals hidden Unicode
# characters. Learn more about bidirectional Unicode characters.
library(contextual)
library(data.table)
library(animation)

## 1. Bandit Simulation ---------------------------------------------------------------------------
horizon <- 50
# Weights of the three Bernoulli arms (presumably per-arm reward probabilities).
weights <- c(0.5, 0.4, 0.6)

policy <- ThompsonSamplingPolicy$new()
bandit <- BasicBernoulliBandit$new(weights = weights)
agent <- Agent$new(policy, bandit, "TS")
simulator <- Simulator$new(agents = agent,
                           horizon = horizon,
                           set_seed = 22,
                           save_theta = TRUE,
                           simulations = 1)
hist <- simulator$run()

# Saved theta per step. The animation reads columns 1-3 and 4-6 as the two
# Beta shape parameters of arms 1-3, and draws vertical lines at columns 7-9
# (presumably the per-arm draws) -- NOTE(review): confirm the column order
# against get_theta().
theta <- data.frame(hist$get_theta("TS", to_numeric_matrix = TRUE))
# Prepend a row of ones (Beta(1, 1) shape parameters) so that row 1 is the
# prior and row i + 1 holds the values saved at step i.
td <- data.table(rbind(rep(1, ncol(theta)), theta))
## 2. Bandit Animation ----------------------------------------------------------------------------

# Draw the Beta density of each of the three arms using row i of td, titled
# with the choice and reward of step i.
#
# td: per-step theta table; columns 1-3 are the first and columns 4-6 the
#     second Beta shape parameter of arms 1-3 (green, red, blue).
# i:  step index; used as the row of td and as the index into the global
#     simulation history `hist` for the title.
plot_curves <- function(td, i) {
  plot(NULL, xlim = c(0, 1), ylim = c(0, 6), ylab = "Density", xlab = "Theta",
       main = paste0("ThompsonSampling\nt = ",
                     # i is already the step number; the former trunc(i/2)+1
                     # mislabelled every frame after the first.
                     sprintf("%03d", i),
                     " | choice = ",
                     hist$data$choice[[i]],
                     " | reward = ",
                     hist$data$reward[[i]], "\n"))
  shape1 <- as.numeric(td[i, 1:3])
  shape2 <- as.numeric(td[i, 4:6])
  # Arms 2 and 3 are shifted down slightly so identical densities stay distinguishable.
  curve(dbeta(x, shape1[1], shape2[1]), from = 0, to = 1, col = "green", add = TRUE)
  curve(dbeta(x, shape1[2], shape2[2]) - 0.05, from = 0, to = 1, col = "red", add = TRUE)
  curve(dbeta(x, shape1[3], shape2[3]) - 0.1, from = 0, to = 1, col = "blue", add = TRUE)
  legend("bottomright", bty = "n", xpd = TRUE, inset = c(-0.42, 0), title = "Arms",
         legend = c("Arm 1", "Arm 2", "Arm 3"), fill = c("green", "red", "blue"))
}
# Draw a vertical line at each of columns 7-9 of row i of td, colored green,
# red and blue for arms 1-3 (drawn in that order).
# as.numeric() yields a plain vector whether td is a matrix, data.frame or
# data.table, matching how plot_curves() reads td.
plot_lines <- function(td, i) {
  arm_values <- as.numeric(td[i, 7:9])
  abline(v = arm_values, col = c("green", "red", "blue"))
}
message("Starting compilation of animation")
animation::ani.options(interval = 0.06, ani.width = 450, ani.height = 400, verbose = FALSE)
saveHTML({
  # Two frames per step: first the densities from row i of td, then the same
  # densities with vertical lines at columns 7-9 of row i + 1.
  for (i in seq_len(horizon)) {
    par(mar = c(5, 4, 4, 9))  # wide right margin leaves room for the legend
    plot_curves(td, i)
    ani.pause()
    plot_curves(td, i)
    plot_lines(td, i + 1)
    ani.pause()
  }
}, htmlfile = "index.html", img.name = "ts", navigator = FALSE, imgdir = "ts",
autoplay = FALSE, single.opts = "'theme': 'light', 'utf8': false,'controls':
['first', 'previous', 'play', 'next', 'last', 'loop','speed']")
message("Completed animation")
# Best-effort close of any device still open; dev.off() errors when none is.
invisible(tryCatch(dev.off(), error = function(e) NULL))
# Note (GitHub): this file contains bidirectional Unicode text that may be interpreted or compiled
# differently than it appears. To review it, open the file in an editor that reveals hidden Unicode
# characters. Learn more about bidirectional Unicode characters.
library(contextual)
library(data.table)
library(animation)

## 1. Bandit Simulation ---------------------------------------------------------------------------
horizon <- 50
# Weights of the five Bernoulli arms (presumably per-arm reward probabilities).
weights <- c(0.5, 0.1, 0.8, 0.2, 0.4)

policy <- UCB1Policy$new()
bandit <- BasicBernoulliBandit$new(weights = weights)
agent <- Agent$new(policy, bandit, "UCB1")
# NOTE(review): unlike the EG and TS scripts no set_seed is passed here, so
# reproducibility depends on Simulator's default -- confirm.
simulator <- Simulator$new(agents = agent,
                           horizon = horizon,
                           save_theta = TRUE,
                           simulations = 1)
hist <- simulator$run()

# Saved theta per step as a data.table; the animation reads columns 1..k as
# per-arm counts and k+1..2k as per-arm mean rewards, plus t/choice/reward.
td <- data.table(hist$get_theta("UCB1", to_numeric_matrix = TRUE))
## 2. Bandit Animation ----------------------------------------------------------------------------
# Number of arms.
k <- hist$data$k[1]

# Per-step color matrix: the chosen arm is green when it returned reward 1 and
# red when it returned 0; all other arms stay gray. One which() result per
# color selects both the rows and the chosen arm, so each cbind() pair lines
# up (the original used two different predicates for them).
rewarded_rows <- which(td[["reward"]] == 1)
unrewarded_rows <- which(td[["reward"]] == 0)
color_matrix <- matrix("gray", nrow = nrow(td), ncol = k)
color_matrix[cbind(rewarded_rows, td[["choice"]][rewarded_rows])] <- "green"
color_matrix[cbind(unrewarded_rows, td[["choice"]][unrewarded_rows])] <- "red"
colnames(color_matrix) <- paste0("C", seq_len(k))
cm <- as.data.frame(color_matrix, stringsAsFactors = FALSE)

# The first k repeats are random; clamp their t to k before taking log(t).
td$t[seq_len(k)] <- k

# UCB1-style exploration bonus per step and arm: sqrt(2 * log(t) / n), with n
# the per-arm counts in columns 1..k. Arms with n = 0 give Inf/NaN, which are
# zeroed. The result is a plain numeric matrix (rows = steps, cols = arms).
sd <- sqrt(2 * log(td$t) / as.matrix(td[, 1:k]))
sd[!is.finite(sd)] <- 0

# x positions of the arms on the plot.
x <- seq_len(k)
# Plot each arm's value y at position x as a filled square colored by col,
# titled with step i and that step's choice and reward. For i = 0 (the frame
# before the first step) choice and reward are shown as 0.
# Reads the globals `k` (number of arms, for xlim) and `hist` (simulation history).
plot_bandit <- function(x, y, col, i) {
  # Scalar condition: plain if/else rather than ifelse().
  if (i > 0) {
    choice <- hist$data$choice[[i]]
    reward <- hist$data$reward[[i]]
  } else {
    choice <- 0
    reward <- 0
  }
  plot(x, y, pch = 15, cex = 2.5, ylim = c(-2.3, 2.9), xlim = c(0.5, k + 0.5),
       ylab = "Average Reward", xlab = "Arm", col = col,
       main = paste0("UCB1\nt = ",
                     sprintf("%03d", i),
                     " | choice = ",
                     choice,
                     " | reward = ",
                     reward, "\n"))
  legend("bottomright", bty = "n", xpd = TRUE, inset = c(-0.42, 0),
         title = "Choice", legend = c("Rewarded 1", "Rewarded 0"), fill = c("green", "red"))
}
message("Starting compilation of animation")
animation::ani.options(interval = 0.06, ani.width = 450, ani.height = 400, verbose = FALSE)
saveHTML({
  par(mar = c(5, 4, 4, 9))  # wide right margin leaves room for the legend
  color_cols <- paste0("C", seq_len(k))
  # Frame 0: before the simulation starts, all arms gray at 0.
  plot_bandit(x, rep(0, k), rep("gray", k), 0)
  ani.pause()
  # Frame 1: still at 0, but colored by the first choice.
  plot_bandit(x, rep(0, k), unlist(cm[1, color_cols]), 1)
  ani.pause()
  # Remaining frames: means (columns k+1..2k of td) and bonus intervals from
  # row i, colored by the choice of step i + 1.
  for (i in seq_len(horizon - 1)) {
    y_means <- as.numeric(td[i, (k + 1):(2 * k)])
    y_sd <- as.numeric(sd[i, ])
    step_cols <- unlist(cm[i + 1, color_cols])
    plot_bandit(x, y_means, step_cols, i + 1)
    # Dashed: highest upper bound; dotted: highest mean.
    abline(h = max(y_means + y_sd) + 0.02, col = "gray", lty = 2)
    abline(h = max(y_means) + 0.02, col = "gray", lty = 3)
    arrows(x, y_means + y_sd, x, y_means - y_sd, angle = 90, code = 3, length = 0.1, lwd = 2,
           col = step_cols)
    ani.pause()
  }
}, htmlfile = "index.html", img.name = "ucb", navigator = FALSE, imgdir = "ucb",
autoplay = FALSE, single.opts = "'theme': 'light', 'utf8': false,'controls':
['first', 'previous', 'play', 'next', 'last', 'loop','speed']")
message("Completed animation")
# Best-effort close of any device still open; dev.off() errors when none is.
invisible(tryCatch(dev.off(), error = function(e) NULL))
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.