Skip to content

Instantly share code, notes, and snippets.

@stucchio
Last active April 3, 2022 21:17
Show Gist options
  • Save stucchio/5383015 to your computer and use it in GitHub Desktop.
Save stucchio/5383015 to your computer and use it in GitHub Desktop.
Beta-distribution Bandit
from numpy import *
from scipy.stats import beta
class BetaBandit(object):
    """Bernoulli multi-armed bandit driven by Thompson sampling.

    Each option ("arm") keeps a count of trials and successes.  The
    posterior over an arm's success rate is
    ``Beta(prior[0] + successes, prior[1] + failures)``; a recommendation
    is made by drawing one sample from every posterior and picking the
    arm whose sample is largest.

    Attributes are plain public fields:
        trials      -- int array, number of times each option was tried
        successes   -- int array, number of successes for each option
        num_options -- number of options being compared
        prior       -- (alpha, beta) pair shared by all options
    """

    def __init__(self, num_options=2, prior=(1.0, 1.0)):
        """Create a bandit with ``num_options`` arms and no observations.

        ``prior`` is the ``(alpha, beta)`` pair of the Beta prior applied
        to every arm; the default ``(1.0, 1.0)`` is the uniform prior.
        """
        self.trials = zeros(shape=(num_options,), dtype=int)
        self.successes = zeros(shape=(num_options,), dtype=int)
        self.num_options = num_options
        self.prior = prior

    def add_result(self, trial_id, success):
        """Record one trial of option ``trial_id``.

        ``success`` is interpreted by truthiness: a truthy value counts
        as a success, anything else as a failure.
        """
        self.trials[trial_id] += 1
        if success:
            self.successes[trial_id] += 1

    def get_recommendation(self):
        """Return the index (a Python ``int``) of the option to try next.

        One value is drawn from each arm's Beta posterior and the index
        of the largest draw is returned (the first one on ties).
        """
        # Posterior parameters for every arm at once; adding the float
        # prior promotes the integer count arrays to float.
        alpha = self.prior[0] + self.successes
        beta_param = self.prior[1] + (self.trials - self.successes)
        # A single vectorized draw avoids building a frozen scipy
        # distribution per arm on every call.
        sampled_theta = beta.rvs(alpha, beta_param)
        return int(argmax(sampled_theta))
from beta_bandit import *
from numpy import *
from scipy.stats import beta
import random
# True conversion probability of each title; the index is the title id.
theta = (0.25, 0.35)


def is_conversion(title):
    """Simulate one page view of ``title`` and report whether it converted.

    Returns ``True`` with probability ``theta[title]`` and ``False``
    otherwise, using the ``random`` module's generator.
    """
    return random.random() < theta[title]
# Not read anywhere in the visible script; kept so the module-level name
# still exists for any code that may reference it.
conversions = [0, 0]

# Number of simulated page views.
N = 100000

# History arrays: row i holds the bandit's cumulative per-title counts
# after page view i (column 0 = title 0, column 1 = title 1).
trials = zeros(shape=(N, 2))
successes = zeros(shape=(N, 2))

bb = BetaBandit()
for i in range(N):
    # Ask the bandit which title to show, simulate the visitor, and feed
    # the outcome back.
    choice = bb.get_recommendation()
    conv = is_conversion(choice)
    bb.add_result(choice, conv)

    # Snapshot the running totals for this time step.  The history rows
    # are written only here: indexing them by `choice` (0 or 1) would
    # overwrite the records of the first two time steps.
    trials[i] = bb.trials
    successes[i] = bb.successes
from pylab import *
# Top panel: cumulative number of times each title was shown, log-log.
subplot(211)
n = arange(N) + 1  # 1-based count of page views, so the log x-axis is valid
loglog(n, trials[:, 0], label="title 0")
loglog(n, trials[:, 1], label="title 1")
legend()
xlabel("Number of trials")
ylabel("Number of trials/title")

# Bottom panel: overall running conversion rate against reference levels.
subplot(212)
semilogx(n, (successes[:, 0] + successes[:, 1]) / n, label="CTR")
# Reference lines: 0.35 and 0.25 are the two simulated rates in `theta`,
# and 0.30 is their mean (what picking a title uniformly at random yields).
semilogx(n, full(N, 0.35), label="Best CTR")
semilogx(n, full(N, 0.30), label="Random chance CTR")
semilogx(n, full(N, 0.25), label="Worst CTR")
# The x-axis is logarithmic, so its lower limit must be positive; start
# at the first trial rather than at 0.
axis([1, N, 0.15, 0.45])
xlabel("Number of trials")
ylabel("CTR")
legend()
show()
@jvcodell
Copy link

jvcodell commented Mar 26, 2018

In the iteration over time steps:

for i in range(N):
choice = bb.get_recommendation()
trials[choice] = trials[choice]+1
conv = is_conversion(choice)
bb.add_result(choice, conv)

trials[i] = bb.trials
successes[i] = bb.successes

Can you explain what the line "trials[choice] = trials[choice]+1" does? The first dimension of trials is over the time horizon, isn't it? I'm confused about why you do that and then set "trials[i] = bb.trials" later on. It seems like "trials[i] = bb.trials" is all that's needed, since you really just want to record what action has been tried in this time step... is that right?

Thanks in advance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment