Created
January 26, 2021 08:56
-
-
Save dharma6872/40d15ea964362b2b4cb95e76ea47b557 to your computer and use it in GitHub Desktop.
[Part 1: Multi-Armed Bandit Problem] #reinforcement-learning (강화학습)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np


# Reward function: pulling an arm pays out 1 with that arm's success
# probability, otherwise 0.
def pull_bandit_arm(bandits, bandit_number):
    """Pull the arm at index bandit_number; return 1 on success, 0 otherwise."""
    draw = np.random.uniform()
    # Success iff the uniform draw falls at or below the arm's probability.
    success = draw <= bandits[bandit_number]
    return int(success)
# Action-selection function implementing an epsilon-greedy policy:
# explore with probability epsilon, otherwise exploit the best-known arm.
def take_epsilon_greedy_action(epsilon, average_rewards):
    """Return a random arm index with probability epsilon, else the greedy arm."""
    if np.random.uniform() < epsilon:
        # Explore: choose any arm uniformly at random.
        return np.random.randint(0, len(average_rewards))
    # Exploit: choose the arm with the highest observed average reward.
    return np.argmax(average_rewards)
if __name__ == "__main__":
    # Probability of success of each bandit arm.
    bandits = [0.1, 0.3, 0.05, 0.55, 0.4]
    num_iterations = 1000
    epsilon = 0.1  # exploration rate for the epsilon-greedy policy

    # Per-arm running statistics used to pick the greedy action each step.
    total_rewards = [0 for _ in range(len(bandits))]
    total_attempts = [0 for _ in range(len(bandits))]
    average_rewards = [0.0 for _ in range(len(bandits))]

    # Run exactly num_iterations episodes. The previous
    # range(num_iterations + 1) executed one extra episode, so the final
    # "Total observed reward in the N episodes" message over-counted by one.
    for iteration in range(1, num_iterations + 1):
        action = take_epsilon_greedy_action(epsilon, average_rewards)
        reward = pull_bandit_arm(bandits, action)

        # Update the chosen arm's statistics.
        total_rewards[action] += reward
        total_attempts[action] += 1
        average_rewards[action] = total_rewards[action] / float(total_attempts[action])

        # Progress report every 100 episodes.
        if iteration % 100 == 0:
            print('Average reward for bandits in iteration {} is {}'.format(iteration,
                  ['{:.2f}'.format(elem) for elem in average_rewards]))

    # Print final results.
    best_bandit = np.argmax(average_rewards)
    print('\nBest bandit is {} with an average observed reward of {:.4f}'
          .format(best_bandit, average_rewards[best_bandit]))
    print('Total observed reward in the {} episodes has been {}'
          .format(num_iterations, sum(total_rewards)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment