Skip to content

Instantly share code, notes, and snippets.

@okanyenigun
Last active November 14, 2022 09:57
Show Gist options
  • Save okanyenigun/d8afec5fbcb5063be9b77540326030fc to your computer and use it in GitHub Desktop.
Save okanyenigun/d8afec5fbcb5063be9b77540326030fc to your computer and use it in GitHub Desktop.
Multi-armed bandit with Thompson sampling
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
class Pub:
    """One bandit arm whose score is drawn from a normal distribution.

    Attributes:
        mu: Mean of the arm's score distribution.
        sigma: Standard deviation of the arm's score distribution.
        n: How many scores have been drawn from this arm so far.
        total_score: Sum of every score drawn from this arm so far.
    """

    def __init__(self, mu: float, sigma: float):
        self.mu = mu
        self.sigma = sigma
        # Visit counter and running score sum both start at zero.
        self.n = 0
        self.total_score = 0

    def draw_score_from_distribution(self) -> float:
        """Visit the arm once: draw a N(mu, sigma) score and record it.

        Returns:
            The score that was just drawn.
        """
        sample = np.random.normal(self.mu, self.sigma)
        # Bump the visit count and fold the new score into the running total.
        self.n, self.total_score = self.n + 1, self.total_score + sample
        return sample
class PubThompsonSampler(Pub):
    """A `Pub` arm that tracks a normal posterior over its unknown mean.

    The score noise ``sigma`` is treated as known, and the mean gets a
    Normal(prior_mu, prior_sigma**2) prior. Under these assumptions the
    posterior is again normal. It is recomputed in closed form from the visit
    count ``n`` and the running ``total_score`` kept by the base class.

    Args:
        mu: Mean of the arm's score distribution.
        sigma: Standard deviation of the scores, assumed known by the update.
        prior_mu: Mean of the prior over ``mu``. Defaults to 0.
        prior_sigma: Standard deviation of the prior over ``mu``. Defaults to
            1000, a very wide prior relative to the example arms.
    """

    def __init__(
        self,
        mu: float,
        sigma: float,
        *,
        prior_mu: float = 0,
        prior_sigma: float = 1000,
    ):
        super().__init__(mu, sigma)
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        # Before any observation, the posterior is the prior itself.
        self.post_mu = self.prior_mu
        self.post_sigma = self.prior_sigma

    def get_mu_from_current_distribution(self) -> float:
        """Draw one candidate mean from the current posterior (the Thompson step)."""
        return np.random.normal(self.post_mu, self.post_sigma)

    def update_current_distribution(self) -> None:
        """Recompute the posterior over ``mu`` from all scores drawn so far.

        This is the normal-normal update with known variance:
        posterior precision = 1/prior_sigma**2 + n/sigma**2, and
        posterior mean = posterior variance
        * (prior_mu/prior_sigma**2 + total_score/sigma**2).
        """
        post_var = 1.0 / (1.0 / self.prior_sigma**2 + self.n / self.sigma**2)
        self.post_sigma = np.sqrt(post_var)
        # Use the variance directly rather than re-squaring post_sigma.
        self.post_mu = post_var * (
            self.prior_mu / self.prior_sigma**2 + self.total_score / self.sigma**2
        )
def draw_distributions(
    Pubs: list,
    i: int,
    *,
    xlim: tuple = (-20, 20),
    n_samples: int = 10000,
) -> None:
    """Plot a KDE of each arm's current posterior over its mean, then show it.

    Args:
        Pubs: Arms exposing ``mu``, ``post_mu`` and ``post_sigma``, such as
            ``PubThompsonSampler`` instances.
        i: Zero-based iteration index. The title shows it 1-based.
        xlim: Limits of the x axis. Defaults to (-20, 20).
        n_samples: Number of posterior draws fed to each KDE. Defaults to 10000.
    """
    for p in Pubs:
        samps = np.random.normal(p.post_mu, p.post_sigma, n_samples)
        # `fill` replaces `shade`, which is deprecated since seaborn 0.11.
        # Labelling each curve here ties its legend entry to the right arm,
        # instead of relying on artist order.
        sns.kdeplot(samps, fill=True, label=f"mu={p.mu}")
    plt.title(f"Iteration {i + 1}", fontsize=20)
    plt.legend(fontsize=12)
    plt.xlim(*xlim)
    plt.xlabel("Average Score", fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.show()
# Three arms with true means 10, 8 and 5 and unit score noise.
Pubs = [PubThompsonSampler(10, 1), PubThompsonSampler(8, 1), PubThompsonSampler(5, 1)]
NDAYS = 120

for i in range(NDAYS):
    # Plot the posteriors on the first 10 days and on the final day. Each plot
    # is drawn before that day's pull.
    if i < 10 or i == NDAYS - 1:
        draw_distributions(Pubs, i)
    # Thompson step: sample one candidate mean from each arm's posterior...
    post_samples = [p.get_mu_from_current_distribution() for p in Pubs]
    # ...and visit the arm with the largest sample (the first one on ties).
    idx = int(np.argmax(post_samples))
    # Draw a score from that arm and refresh its posterior. The score is
    # accumulated inside the arm, so it is not kept here.
    Pubs[idx].draw_score_from_distribution()
    Pubs[idx].update_current_distribution()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment