Skip to content

Instantly share code, notes, and snippets.

@sudeepraja
Last active June 18, 2023 23:04
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sudeepraja/6f57e29f2854828793fb75ba5a09d152 to your computer and use it in GitHub Desktop.
Save sudeepraja/6f57e29f2854828793fb75ba5a09d152 to your computer and use it in GitHub Desktop.
Code to test different exploration strategies for multi-armed bandits
import numpy as np
import matplotlib.pyplot as plt
import math
# Experiment configuration (module-level; epsilon and temperature are
# mutated by pick_arm / the main loop, so they are globals, not constants).
number_of_bandits=10  # independent bandit problems averaged per strategy
number_of_arms=10  # arms per bandit problem
number_of_pulls=30000  # pulls per bandit run
epsilon=0.3  # exploration probability for epsilon-greedy
temperature=10.0  # initial softmax temperature for Boltzmann exploration
min_temp = 0.1  # floor for the annealed temperature
decay_rate=0.999  # multiplicative decay applied to temperature on each Boltzmann pick
def pick_arm(q_values, counts, strategy, success, failure):
    """Select an arm index according to the given exploration strategy.

    Parameters
    ----------
    q_values : np.ndarray
        Current estimated mean reward per arm.
    counts : np.ndarray
        Number of times each arm has been pulled so far (used by UCB).
    strategy : str
        One of "random", "greedy", "egreedy", "boltzmann", "ucb", "thompson".
    success, failure : np.ndarray
        Per-arm counts of 1-rewards and 0-rewards (used by Thompson sampling).

    Returns
    -------
    int
        Index of the chosen arm.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of the recognized names (the original
        code silently returned None here).
    """
    # epsilon/temperature live at module level so that the annealed
    # temperature persists across calls within a single bandit run.
    global epsilon
    global temperature

    if strategy == "random":
        return np.random.randint(0, len(q_values))

    if strategy == "greedy":
        # Break ties among maximal arms uniformly at random.
        best_arms = np.argwhere(q_values == np.max(q_values)).flatten()
        return best_arms[np.random.randint(0, len(best_arms))]

    if strategy == "egreedy":
        # Exploit with probability 1 - epsilon, otherwise explore uniformly.
        if np.random.random() > epsilon:
            best_arms = np.argwhere(q_values == np.max(q_values)).flatten()
            return best_arms[np.random.randint(0, len(best_arms))]
        return np.random.randint(0, len(q_values))

    if strategy == "boltzmann":
        # Softmax over value estimates; the temperature anneals toward
        # min_temp, shifting from exploration to exploitation over time.
        probs = np.exp(q_values / temperature)
        temperature = max(temperature * decay_rate, min_temp)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs)

    if strategy == "ucb":
        # UCB1 exploration bonus; +0.001 avoids division by zero for
        # arms that have never been pulled.
        total_counts = np.sum(counts)
        q_values_ucb = q_values + np.sqrt(
            np.reciprocal(counts + 0.001) * 2 * math.log(total_counts + 1.0)
        )
        best_arms = np.argwhere(q_values_ucb == np.max(q_values_ucb)).flatten()
        return best_arms[np.random.randint(0, len(best_arms))]

    if strategy == "thompson":
        # Draw one sample from each arm's posterior Beta(successes + 1,
        # failures + 1) and play the arm with the largest sample
        # (vectorized instead of a per-arm Python loop).
        return np.argmax(np.random.beta(success + 1, failure + 1))

    raise ValueError("unknown strategy: %r" % (strategy,))
# Run every exploration strategy on freshly sampled Bernoulli bandits and
# plot, for each strategy, the percentage of pulls spent on the optimal
# arm, averaged over all bandit problems.
fig = plt.figure()
ax = fig.add_subplot(111)

for strategy_name in ["greedy", "random", "egreedy", "boltzmann", "ucb", "thompson"]:
    best_arm_counts = np.zeros((number_of_bandits, number_of_pulls))

    for bandit_idx in range(number_of_bandits):
        # Fresh bandit problem: arm means drawn uniformly from [0, 1).
        arm_means = np.random.rand(number_of_arms)
        best_arm = np.argmax(arm_means)

        q_values = np.zeros(number_of_arms)
        counts = np.zeros(number_of_arms)
        success = np.zeros(number_of_arms)
        failure = np.zeros(number_of_arms)

        for pull_idx in range(number_of_pulls):
            arm = pick_arm(q_values, counts, strategy_name, success, failure)
            reward = np.random.binomial(1, arm_means[arm])

            # Incremental update of the running-mean reward estimate.
            counts[arm] += 1.0
            q_values[arm] += (reward - q_values[arm]) / counts[arm]
            success[arm] += reward
            failure[arm] += 1 - reward

            best_arm_counts[bandit_idx, pull_idx] = counts[best_arm] * 100.0 / (pull_idx + 1)

        # Reset the annealed globals so each bandit run starts fresh.
        temperature = 10.0
        epsilon = 0.3

    ys = np.mean(best_arm_counts, axis=0)
    ax.plot(range(len(ys)), ys, label=strategy_name)

plt.xlabel('Steps')
plt.ylabel('% Optimal Arm pulls')
plt.tight_layout()
plt.legend()
plt.ylim((0, 110))
plt.show()
@SyedKhurramMahmud
Copy link

Best Work So Far on MABs. Hats off. Very concise and deep.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment