Last active
November 14, 2022 09:57
-
-
Save okanyenigun/d8afec5fbcb5063be9b77540326030fc to your computer and use it in GitHub Desktop.
multi armed bandit thompson
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
class Pub:
    """A pub whose visit scores are drawn from a normal distribution.

    Every draw is tallied so that the number of visits and the running sum
    of observed scores are available on the instance.
    """

    def __init__(self, mu: float, sigma: float):
        """Store the score distribution parameters and reset the tallies.

        Args:
            mu: Mean handed to ``np.random.normal`` when a score is drawn.
            sigma: Standard deviation handed to ``np.random.normal``.
        """
        self.mu = mu
        self.sigma = sigma
        # Number of scores drawn so far.
        self.n: int = 0
        # Running sum of every score drawn so far.
        self.total_score: float = 0.0

    def draw_score_from_distribution(self) -> float:
        """Draw one score, add it to the tallies, and return it.

        Returns:
            The value sampled from ``np.random.normal(self.mu, self.sigma)``.
        """
        score = np.random.normal(self.mu, self.sigma)
        # Record the visit and fold the observation into the running sum.
        self.n += 1
        self.total_score += score
        return score
class PubThompsonSampler(Pub):
    """Pub that also carries a normal belief about its own mean score.

    The belief starts at the prior and is refreshed from the visit count and
    score total tracked by ``Pub`` each time
    ``update_current_distribution`` is called.
    """

    def __init__(self, mu: float, sigma: float,
                 prior_mu: float = 0, prior_sigma: float = 1000):
        """Initialise the pub and set the posterior equal to the prior.

        Args:
            mu: Mean of the score distribution (forwarded to ``Pub``).
            sigma: Standard deviation of the score distribution (forwarded
                to ``Pub``); also used as the known observation noise in
                the posterior update.
            prior_mu: Mean of the normal prior over the pub's mean score.
            prior_sigma: Standard deviation of that prior. The large default
                makes the starting belief very wide.
        """
        super().__init__(mu, sigma)
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        # Before any observation the posterior is just the prior.
        self.post_mu = self.prior_mu
        self.post_sigma = self.prior_sigma

    def get_mu_from_current_distribution(self) -> float:
        """Return one sample from the current posterior over the mean."""
        return np.random.normal(self.post_mu, self.post_sigma)

    def update_current_distribution(self) -> None:
        """Recompute the posterior from the prior, ``n`` and ``total_score``.

        Precisions (inverse variances) add: the prior contributes
        ``1 / prior_sigma**2`` and the ``n`` observations contribute
        ``n / sigma**2``. The posterior mean is the posterior variance times
        the precision-weighted sum of the prior mean and the score total.
        """
        prior_precision = 1 / self.prior_sigma ** 2
        noise_precision = 1 / self.sigma ** 2
        post_variance = 1 / (prior_precision + self.n * noise_precision)
        self.post_sigma = np.sqrt(post_variance)
        # Use the variance directly rather than squaring the rounded sqrt.
        self.post_mu = post_variance * (
            self.prior_mu * prior_precision
            + self.total_score * noise_precision
        )
def draw_distributions(Pubs: list, i: int) -> None:
    """Plot a density estimate of every pub's posterior and show the figure.

    Args:
        Pubs: Objects exposing ``post_mu``, ``post_sigma`` and ``mu``.
        i: Zero-based iteration index; the title displays ``i + 1``.
    """
    for p in Pubs:
        # Approximate each posterior by sampling it, then smooth with a KDE.
        samps = np.random.normal(p.post_mu, p.post_sigma, 10000)
        # `fill` is the supported spelling of the deprecated `shade` keyword.
        sns.kdeplot(samps, fill=True)
    plt.title(f"Iteration {i + 1}", fontsize=20)
    # Legend entries follow the plotting order, one per pub.
    plt.legend([f"mu={p.mu}" for p in Pubs], fontsize=12)
    plt.xlim(-20, 20)
    plt.xlabel('Average Score', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.show()
# Three pubs whose scores are drawn with means 10, 8 and 5 and unit
# standard deviation.
Pubs = [
    PubThompsonSampler(10, 1),
    PubThompsonSampler(8, 1),
    PubThompsonSampler(5, 1),
]
NDAYS = 120

for i in range(NDAYS):
    # Plot the beliefs for the first ten days and for the final day only.
    if i < 10 or i == NDAYS - 1:
        draw_distributions(Pubs, i)
    # Thompson sampling step: draw one plausible mean per pub ...
    post_samples = [p.get_mu_from_current_distribution() for p in Pubs]
    # ... and visit the pub whose draw is highest (first one wins on ties).
    idx = int(np.argmax(post_samples))
    # The visit records a new score on the pub; the value is not needed here.
    Pubs[idx].draw_score_from_distribution()
    # Refresh only the visited pub's posterior with the new observation.
    Pubs[idx].update_current_distribution()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment