Code to test different exploration strategies for multi-armed bandits
import numpy as np
import matplotlib.pyplot as plt
import math

number_of_bandits = 10   # independent bandit problems to average over
number_of_arms = 10
number_of_pulls = 30000
epsilon = 0.3            # exploration probability for epsilon-greedy
temperature = 10.0       # initial temperature for Boltzmann exploration
min_temp = 0.1
decay_rate = 0.999

def pick_arm(q_values, counts, strategy, success, failure):
    global epsilon
    global temperature
    if strategy == "random":
        return np.random.randint(0, len(q_values))
    if strategy == "greedy":
        # pick uniformly among the arms with the highest estimated value
        best_arms_value = np.max(q_values)
        best_arms = np.argwhere(q_values == best_arms_value).flatten()
        return best_arms[np.random.randint(0, len(best_arms))]
    if strategy == "egreedy":
        if np.random.random() > epsilon:
            best_arms_value = np.max(q_values)
            best_arms = np.argwhere(q_values == best_arms_value).flatten()
            # epsilon = max(epsilon*decay_rate, min_temp)
            return best_arms[np.random.randint(0, len(best_arms))]
        else:
            # epsilon = max(epsilon*decay_rate, min_temp)
            return np.random.randint(0, len(q_values))
    if strategy == "boltzmann":
        # softmax over value estimates, with a temperature that decays each pull
        probs = np.exp(q_values / temperature)
        temperature = max(temperature * decay_rate, min_temp)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs)
    if strategy == "ucb":
        # UCB1-style exploration bonus; small constants avoid division by zero / log(0)
        total_counts = np.sum(counts)
        q_values_ucb = q_values + np.sqrt(np.reciprocal(counts + 0.001) * 2 * math.log(total_counts + 1.0))
        best_arms_value = np.max(q_values_ucb)
        best_arms = np.argwhere(q_values_ucb == best_arms_value).flatten()
        return best_arms[np.random.randint(0, len(best_arms))]
    if strategy == "thompson":
        # sample a mean for each arm from its Beta posterior and play the best sample
        sample_means = np.zeros(len(counts))
        for i in range(len(counts)):
            sample_means[i] = np.random.beta(success[i] + 1, failure[i] + 1)
        return np.argmax(sample_means)

fig = plt.figure()
ax = fig.add_subplot(111)

for st in ["greedy", "random", "egreedy", "boltzmann", "ucb", "thompson"]:
    best_arm_counts = np.zeros((number_of_bandits, number_of_pulls))
    for i in range(number_of_bandits):
        arm_means = np.random.rand(number_of_arms)   # Bernoulli reward probabilities
        best_arm = np.argmax(arm_means)
        q_values = np.zeros(number_of_arms)
        counts = np.zeros(number_of_arms)
        success = np.zeros(number_of_arms)
        failure = np.zeros(number_of_arms)
        for j in range(number_of_pulls):
            a = pick_arm(q_values, counts, st, success, failure)
            reward = np.random.binomial(1, arm_means[a])
            counts[a] += 1.0
            q_values[a] += (reward - q_values[a]) / counts[a]   # incremental mean update
            success[a] += reward
            failure[a] += 1 - reward
            best_arm_counts[i][j] = counts[best_arm] * 100.0 / (j + 1)
        temperature = 10.0   # reset decayed parameters before the next bandit problem
        epsilon = 0.3
    ys = np.mean(best_arm_counts, axis=0)
    xs = range(len(ys))
    ax.plot(xs, ys, label=st)

plt.xlabel('Steps')
plt.ylabel('% Optimal Arm pulls')
plt.tight_layout()
plt.legend()
plt.ylim((0, 110))
plt.show()
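The script above reports the percentage of pulls spent on the optimal arm. A closely related diagnostic is cumulative (pseudo-)regret. Below is a minimal sketch, not part of the original gist, that reuses pick_arm, number_of_arms and number_of_pulls from the script above, so it only runs if appended after it; the choice of the "thompson" strategy and the regret bookkeeping are illustrative assumptions.

import numpy as np
import matplotlib.pyplot as plt

strategy = "thompson"                       # any key handled by pick_arm above (assumption)
arm_means = np.random.rand(number_of_arms)  # Bernoulli success probabilities for one bandit
best_mean = np.max(arm_means)

q_values = np.zeros(number_of_arms)
counts = np.zeros(number_of_arms)
success = np.zeros(number_of_arms)
failure = np.zeros(number_of_arms)
regret = np.zeros(number_of_pulls)

for t in range(number_of_pulls):
    a = pick_arm(q_values, counts, strategy, success, failure)
    reward = np.random.binomial(1, arm_means[a])
    counts[a] += 1.0
    q_values[a] += (reward - q_values[a]) / counts[a]
    success[a] += reward
    failure[a] += 1 - reward
    # expected regret of this pull, accumulated over time
    regret[t] = (regret[t - 1] if t > 0 else 0.0) + (best_mean - arm_means[a])

plt.plot(regret, label=strategy)
plt.xlabel('Steps')
plt.ylabel('Cumulative regret')
plt.legend()
plt.show()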
Best work so far on MABs. Hats off. Very concise and deep.