Skip to content

Instantly share code, notes, and snippets.

@stucchio
Created February 6, 2014 07:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save stucchio/8839891 to your computer and use it in GitHub Desktop.
Save stucchio/8839891 to your computer and use it in GitHub Desktop.
from pylab import *
import random
from scipy.stats import beta, uniform
# Beta(1, 1) distribution (i.e. uniform on [0, 1]); evaluate_profit samples
# each arm's hidden success probability from it.
prior = beta(1,1)
class Bandit(object):
    """Greedy two-armed strategy driven by observed success rates.

    ``history`` holds one ``(failures, successes)`` pair per arm. Both
    counters start at 1.0, so each arm begins with a rate of 0.5 and the
    rate computation never divides by zero.
    """

    def __init__(self):
        # One (failures, successes) pair for each of the two arms.
        self.history = [(1.0, 1.0) for _ in range(2)]

    def add_history(self, choice, success):
        """Record the outcome of pulling arm ``choice``.

        A truthy ``success`` increments the arm's success counter (second
        slot); anything else increments its failure counter (first slot).
        """
        losses, wins = self.history[choice]
        if success:
            wins = wins + 1
        else:
            losses = losses + 1
        self.history[choice] = (losses, wins)

    def get_choice(self):
        """Return the index of the arm with the highest success rate.

        The rate of an arm is successes / (failures + successes). The index
        is whatever ``argmax`` returns for the list of rates.
        """
        rates = [wins / (losses + wins) for losses, wins in self.history]
        return argmax(rates)
class Random(object):
    """Baseline strategy: picks one of the two arms at random, ignoring feedback."""

    # Indices of the arms this strategy can pick.
    _ARMS = (0, 1)

    def __init__(self):
        """Nothing to set up; the strategy keeps no state."""

    def add_history(self, choice, success):
        """Accept an outcome report and discard it."""
        return None

    def get_choice(self):
        """Return 0 or 1, chosen with ``random.choice``."""
        return random.choice(self._ARMS)
# Standard uniform distribution on [0, 1]; evaluate_profit compares a draw
# from it against an arm's success probability to decide each pull's outcome.
u = uniform(loc=0.0, scale=1.0)
def evaluate_profit(algo, rounds=3):
    """Play one short two-armed game with ``algo`` and return its payoff.

    The success probability of each of the two arms is drawn from the
    module-level ``prior``. On every round the strategy picks an arm, the
    pull succeeds when a draw from the module-level ``u`` falls below that
    arm's probability, and the outcome is reported back to the strategy.

    Args:
        algo: strategy object providing ``get_choice()`` (an arm index,
            0 or 1) and ``add_history(choice, success)``.
        rounds: number of pulls to simulate. Defaults to 3, the value that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        The number of successful pulls, as a float.
    """
    # Hidden per-arm success probabilities for this game.
    p = prior.rvs((2,))
    payoff = 0.0
    for _ in range(rounds):
        choice = algo.get_choice()
        # bool() so the strategy receives a plain True/False, as before,
        # rather than a numpy boolean.
        success = bool(u.rvs() < p[choice])
        if success:
            payoff += 1
        algo.add_history(choice, success)
    return payoff
if __name__ == "__main__":
    # Monte Carlo comparison: average payoff per game of the random baseline
    # versus the greedy bandit, each over N independent games.
    N = 100000
    random_profit = 0.0
    bandit_profit = 0.0
    for _ in range(N):
        # A fresh strategy instance per game, so no history carries over.
        random_profit += evaluate_profit(Random())
        bandit_profit += evaluate_profit(Bandit())
    random_profit /= N
    bandit_profit /= N
    # Single-argument parenthesised print: a valid print statement on
    # Python 2 and a function call on Python 3, emitting the same text.
    print("Random Profit: " + str(random_profit))
    print("Bandit Profit: " + str(bandit_profit))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment