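# Setup (a minimal sketch: the fragment below uses num_ep, num_bandit,
# num_iter, gt_prob, and optimal_choice without defining them; the values
# here are illustrative assumptions, not part of the original snippet).
import numpy as np

np.random.seed(0)                              # reproducibility (assumption)
num_bandit = 10                                # number of arms (assumed example)
num_ep = 200                                   # independent episodes (assumed example)
num_iter = 1000                                # pulls per episode (assumed example)
gt_prob = np.random.uniform(0, 1, num_bandit)  # Bernoulli reward probability per arm (assumed)
optimal_choice = np.argmax(gt_prob)            # index of the best arm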
# m: gradient-based bandit (softmax over preferences)
m_pull_count = np.zeros((num_ep, num_bandit))
m_estimation = np.zeros((num_ep, num_bandit))
m_reward = np.zeros((num_ep, num_iter))
m_optimal_pull = np.zeros((num_ep, num_iter))
m_regret_total = np.zeros((num_ep, num_iter))
for eps in range(num_ep):
    temp_pull_count = np.zeros(num_bandit)
    temp_estimation = np.zeros(num_bandit) + 1 / num_bandit  # initial preferences
    temp_reward = np.zeros(num_iter)
    temp_optimal_pull = np.zeros(num_iter)
    temp_regret = np.zeros(num_iter)
    temp_mean_reward = 0.0
    alpha = 0.8  # preference learning rate
    for t in range(num_iter):
        # select arm via softmax over preferences / get Bernoulli reward /
        # increase count / update the preference vector
        pi = np.exp(temp_estimation) / np.sum(np.exp(temp_estimation))
        current_choice = np.random.choice(num_bandit, p=pi)
        current_reward = 1 if np.random.uniform(0, 1) < gt_prob[current_choice] else 0
        temp_pull_count[current_choice] = temp_pull_count[current_choice] + 1
        # incremental sample mean of rewards, used as the baseline
        # (divide by t + 1 so the mean is correct from the first pull)
        temp_mean_reward = temp_mean_reward + (current_reward - temp_mean_reward) / (t + 1)
        # gradient bandit update: raise the chosen arm's preference, lower the rest
        mask = np.zeros(num_bandit)
        mask[current_choice] = 1
        temp_estimation = mask * (temp_estimation + alpha * (current_reward - temp_mean_reward) * (1 - pi)) + \
                          (1 - mask) * (temp_estimation - alpha * (current_reward - temp_mean_reward) * pi)
        # update cumulative reward, optimal-arm indicator, and cumulative regret
        temp_reward[t] = current_reward if t == 0 else temp_reward[t - 1] + current_reward
        temp_optimal_pull[t] = 1 if current_choice == optimal_choice else 0
        temp_regret[t] = (gt_prob[optimal_choice] - gt_prob[current_choice]) if t == 0 else \
            temp_regret[t - 1] + (gt_prob[optimal_choice] - gt_prob[current_choice])
    m_pull_count[eps, :] = temp_pull_count
    m_estimation[eps, :] = temp_estimation
    m_reward[eps, :] = temp_reward
    m_optimal_pull[eps, :] = temp_optimal_pull
    m_regret_total[eps, :] = temp_regret
print('Ground Truth')
print(gt_prob)
print('Estimated preferences (mean over episodes)')
print(np.around(m_estimation.mean(0), 2))
# the preferences are not probabilities, so min-max rescale them into the
# range of gt_prob for a side-by-side comparison
est = m_estimation.mean(0)
m_estimation = (gt_prob.max() - gt_prob.min()) * (est - est.min()) / (est.max() - est.min()) + gt_prob.min()
print('Estimated, rescaled to ground-truth range')
print(m_estimation)
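# Optional follow-up (a sketch, assuming matplotlib is available): plot the
# episode-averaged cumulative regret and the per-iteration optimal-action rate
# from the matrices collected above.
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(m_regret_total.mean(0))       # mean cumulative regret over episodes
ax1.set_title('Mean cumulative regret')
ax1.set_xlabel('iteration')
ax2.plot(m_optimal_pull.mean(0))       # fraction of episodes pulling the best arm
ax2.set_title('P(optimal arm pulled)')
ax2.set_xlabel('iteration')
plt.tight_layout()
plt.show()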