Skip to content

Instantly share code, notes, and snippets.

@stucchio
Created February 7, 2014 18:55
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save stucchio/8869292 to your computer and use it in GitHub Desktop.
Save stucchio/8869292 to your computer and use it in GitHub Desktop.
import matplotlib
matplotlib.use("WXAgg")
from pylab import *
from scipy.stats import beta, uniform, norm
class BetaBandit(object):
    """Thompson-sampling bandit over ``num_options`` Bernoulli arms.

    Each arm's posterior is a Beta distribution whose parameters are the
    shared ``prior`` pseudo-counts plus that arm's observed successes
    (added to ``prior[0]``) and failures (added to ``prior[1]``).
    """

    def __init__(self, num_options=2, prior=(1.0, 1.0)):
        self.num_options = num_options
        self.prior = prior
        # Per-arm integer counters, updated in place by add_result.
        self.trials = zeros(shape=(num_options,), dtype=int)
        self.successes = zeros(shape=(num_options,), dtype=int)

    def add_result(self, trial_id, success):
        """Record one pull of arm ``trial_id``; a truthy ``success`` counts as a win."""
        self.trials[trial_id] += 1
        if success:
            self.successes[trial_id] += 1

    def get_recommendation(self):
        """Return the index of the arm to pull next.

        Draws one sample from every arm's Beta posterior (in arm order) and
        returns the index of the largest draw; ties go to the lowest index.
        """
        draws = []
        for arm in range(self.num_options):
            wins = self.successes[arm]
            posterior = beta(self.prior[0] + wins,
                             self.prior[1] + self.trials[arm] - wins)
            draws.append(posterior.rvs())
        # max() keeps the first maximal element, matching list.index(max(...)).
        return max(range(len(draws)), key=draws.__getitem__)
# Prior over each arm's true click-through rate: Beta(1, 20), mean 1/21.
prior = beta(a=1, b=20)
def evaluate(N, num_options=2):
    """Simulate ``N`` impressions allocated by a Thompson-sampling bandit.

    True arm rates are drawn from the module-level ``prior``. In each round the
    bandit recommends an arm, a success is simulated with probability equal to
    that arm's rate, and the outcome is fed back to the bandit.

    Parameters
    ----------
    N : number of impressions to simulate; must be positive (N == 0 raises
        ZeroDivisionError).
    num_options : number of arms (default 2, the previously hard-coded value).

    Returns
    -------
    float : realized success rate, i.e. total successes divided by ``N``.
    """
    p = prior.rvs((num_options,))
    b = BetaBandit(num_options)
    successes = 0
    for _ in range(N):
        choice = b.get_recommendation()
        clicked = uniform.rvs() < p[choice]
        b.add_result(choice, clicked)
        if clicked:
            successes += 1
    return float(successes) / float(N)
def p_value(N, s, two_sided=False):
    """Return a z-test p-value for the difference between two empirical rates.

    Parameters
    ----------
    N : pair of trial counts ``(n0, n1)``.
    s : pair of success counts ``(s0, s1)``; any array-like is accepted.
    two_sided : if False (default, the original behaviour) return the
        upper-tail probability ``P(Z > |z|)``. That is HALF the usual two-sided
        p-value, so comparing it to 0.05 is effectively a two-sided test at the
        10% level. Pass True to get ``2 * P(Z > |z|)``.

    Returns
    -------
    1 when the test is degenerate, i.e. either arm has zero trials or the
    unpooled standard error is exactly zero (each observed rate is 0 or 1).
    Otherwise the p-value as a float.
    """
    trials = array(N, dtype=float)
    if trials[0] == 0 or trials[1] == 0:
        # Rates would be 0/0 = NaN and propagate through every step below.
        return 1
    empirical_ctr = array(s, dtype=float) / trials
    variance = (empirical_ctr[0] * (1.0 - empirical_ctr[0]) / trials[0]
                + empirical_ctr[1] * (1.0 - empirical_ctr[1]) / trials[1])
    std_error = sqrt(variance)
    if std_error == 0:
        return 1
    z_value = (empirical_ctr[1] - empirical_ctr[0]) / std_error
    # norm.sf computes 1 - cdf directly, avoiding the cancellation that makes
    # 1 - cdf(z) collapse to exactly 0 for large z.
    tail = norm.sf(abs(z_value))
    return 2.0 * tail if two_sided else tail
def eighty_twenty(N, alpha=0.05):
    """Simulate an explore-then-exploit A/B test and return its realized rate.

    Two arm rates are drawn from the module-level ``prior``. For the first
    ``N // 10`` rounds both arms are shown once per round (about 20% of the
    traffic). If ``p_value`` on those counts is below ``alpha``, the arm with
    more successes is kept; otherwise arm 0 is kept. The kept arm then gets
    every remaining impression, so exactly ``N`` impressions are simulated.

    Parameters
    ----------
    N : total number of impressions; must be positive (N == 0 raises
        ZeroDivisionError).
    alpha : significance threshold compared against ``p_value`` (default 0.05).

    Returns
    -------
    float : total successes divided by ``N``.
    """
    p = prior.rvs((2,))
    # Floor division: N / 10 is a float on Python 3 and range() rejects it.
    explore_rounds = N // 10
    arm_successes = array([0.0, 0.0])
    successes = 0.0
    for _ in range(explore_rounds):
        for k in range(2):
            if uniform.rvs() < p[k]:
                arm_successes[k] += 1
                successes += 1
    # With no exploration data the test would compute 0/0; keep arm 0 instead.
    if (explore_rounds > 0
            and p_value((explore_rounds, explore_rounds), arm_successes) < alpha):
        winner = argmax(arm_successes)
    else:
        winner = 0
    # Remaining impressions = N minus the 2 * explore_rounds already shown.
    # This equals 8*N/10 when N is a multiple of 10, and otherwise still keeps
    # the total at N so the division below is consistent.
    for _ in range(N - 2 * explore_rounds):
        if uniform.rvs() < p[winner]:
            successes += 1
    return float(successes) / float(N)
# Driver: for run lengths M = 50, 100, ..., 500, average each strategy's
# realized rate over nmax independent simulations, print the averages, and
# plot them against M.
mmax = 11
nmax = 1000  # simulations per run length (loop-invariant, so hoisted)
samples = arange(1, mmax) * 50
# Row 0: bandit, row 1: 80/20, row 2: random; one column per run length.
results = zeros(shape=(3, mmax - 1), dtype=float)
for col, M in enumerate(samples):
    bandit_success = 0.0
    eighty_twenty_success = 0.0
    random_success = 0.0
    for i in range(nmax):
        bandit_success += evaluate(M)
        # Random assignment is represented by a single prior draw, which has
        # the same expected rate as a randomly chosen arm.
        random_success += prior.rvs()
        eighty_twenty_success += eighty_twenty(M)
    bandit_success /= nmax
    eighty_twenty_success /= nmax
    random_success /= nmax
    # Column index comes from enumerate: M/50-1 is a float on Python 3 and
    # cannot index an array.
    results[0, col] = bandit_success
    results[1, col] = eighty_twenty_success
    results[2, col] = random_success
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("M = " + str(M))
    print("Bandit success: " + str(bandit_success))
    print("80/20 success: " + str(eighty_twenty_success))
    print("Random success: " + str(random_success))
clf()
plot(samples, results[0], label="Bandit")
plot(samples, results[1], label="80/20")
plot(samples, results[2], label="Random")
legend()
xlabel("number of samples")
ylabel("CTR")
show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment