Skip to content

Instantly share code, notes, and snippets.

@apoorvalal
Created March 18, 2022 17:23
Show Gist options
  • Save apoorvalal/d87429cc75f71904bb0a9de160cf668e to your computer and use it in GitHub Desktop.
Save apoorvalal/d87429cc75f71904bb0a9de160cf668e to your computer and use it in GitHub Desktop.
Thompson sampling: minimal example
# Setup ----
# Attach dependencies with base library() so the script does not rely on the
# non-base helper libreq(), which is not defined in this file. Deliberately no
# rm(list = ls()): it would wipe the caller's workspace when this script is
# source()d, and it does not reset loaded packages or options anyway.
library(data.table)
library(ggplot2)
# Fixed seed so the simulated bandit run (and hence the figure) is reproducible.
set.seed(42)
# %%
# Simulate Thompson sampling on a K-armed Bernoulli bandit.
#
# Each arm k carries a Beta posterior whose shape parameters are its running
# success and failure counts (plus the prior). Every round draws one sample
# per posterior, pulls the arm with the largest draw, observes a 0/1 reward
# drawn with that arm's true probability, and increments the pulled arm's
# success or failure count.
#
# Args:
#   n:             number of rounds (non-negative whole number).
#   K:             number of arms; defaults to length(reward_probs).
#   reward_probs:  numeric vector of length K with true success probabilities
#                  in [0, 1].
#   alpha0, beta0: initial Beta shape parameters (scalar or length K). The
#                  default 1, 1 reproduces the original all-ones initialisation.
#
# Returns a list with:
#   successes, failures: (n + 1) x K matrices (always matrices, even for K = 1).
#     Row t is the state before round t; row n + 1 is the final state.
#   choices:  length-n integer vector of the arm pulled in each round.
#   rewards:  length-n integer vector of the 0/1 reward in each round.
#   cumulative_counts: (n + 1) x K matrix of pull counts; row 1 is all zeros.
thompson <- function(n, K = length(reward_probs), reward_probs,
                     alpha0 = 1, beta0 = 1) {
  stopifnot(
    is.numeric(n), length(n) == 1L, n >= 0,
    is.numeric(K), length(K) == 1L, K >= 1,
    is.numeric(reward_probs),
    # A shorter reward_probs would index NA and poison every later count.
    length(reward_probs) == K,
    all(reward_probs >= 0 & reward_probs <= 1)
  )
  arms <- seq_len(K)
  choices <- rewards <- rep(NA_integer_, n)

  # Successes live in columns 1..K and failures in columns K+1..2K.
  # The extra (n + 1)-th row holds the state after the last update.
  s_f <- matrix(NA_real_, nrow = n + 1, ncol = 2 * K)
  s_f[1, ] <- c(rep_len(alpha0, K), rep_len(beta0, K))
  cumul_choices <- matrix(0, nrow = n + 1, ncol = K)

  for (t in seq_len(n)) {
    # One posterior draw per arm. rbeta recycles the shape vectors element by
    # element, so this consumes the RNG stream like a per-arm loop would.
    theta <- rbeta(K, s_f[t, arms], s_f[t, K + arms])
    choice <- which.max(theta)
    reward <- rbinom(1, 1, reward_probs[choice])
    choices[t] <- choice
    rewards[t] <- reward

    # Carry the state forward, then credit the pulled arm.
    s_f[t + 1, ] <- s_f[t, ]
    s_f[t + 1, choice] <- s_f[t, choice] + reward
    s_f[t + 1, K + choice] <- s_f[t, K + choice] + (1 - reward)

    cumul_choices[t + 1, ] <- cumul_choices[t, ]
    cumul_choices[t + 1, choice] <- cumul_choices[t, choice] + 1
  }

  list(
    successes = s_f[, arms, drop = FALSE],
    failures = s_f[, K + arms, drop = FALSE],
    choices = choices,
    rewards = rewards,
    cumulative_counts = cumul_choices
  )
}
# %%
# Plot cumulative pull counts per arm from a single thompson() run.
#
# Args:
#   n:          number of rounds passed to thompson().
#   r:          vector of true arm reward probabilities; also used (made
#               unique) as the legend labels.
#   plot_theme: ggplot2 theme added to the plot. When NULL (default), uses
#               lal_plot_theme() if a function of that name is visible,
#               otherwise theme_minimal().
#
# Returns a ggplot object. The x value t = 1 is the state before any pull,
# and t = n + 1 is the state after the last pull.
armfig <- function(n, r, plot_theme = NULL) {
  if (is.null(plot_theme)) {
    plot_theme <- if (exists("lal_plot_theme", mode = "function")) {
      lal_plot_theme()
    } else {
      theme_minimal()
    }
  }
  sim <- thompson(n, length(r), r)
  cumul_pulls <- data.table(t = seq_len(n + 1), sim$cumulative_counts)
  # make.unique keeps arms with equal probabilities as separate series instead
  # of collapsing them under one duplicated column name.
  setnames(cumul_pulls, c("t", make.unique(as.character(r))))
  # Spell out id.vars; the original `id =` only worked via partial matching.
  pldf <- melt(cumul_pulls, id.vars = "t")
  ggplot(pldf, aes(t, value, colour = variable, group = variable)) +
    geom_line() +
    labs(y = "cumulative pulls", colour = "reward prob") +
    plot_theme
}
# %%
# Example: three arms with close reward probabilities. print() is explicit so
# the figure is drawn even when this script is run via source(), where
# top-level values are not auto-printed.
print(armfig(n = 1000L, r = c(0.14, 0.2, 0.16)))
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment