@dharma6872
Created January 26, 2021 08:56
[Part 1 Multi-Armed Bandit Problem] #reinforcement-learning
import numpy as np


# Return the reward obtained from a given action.
def pull_bandit_arm(bandits, bandit_number):
    # Pull the arm in position bandit_number and return the obtained reward.
    result = np.random.uniform()
    return int(result <= bandits[bandit_number])  # Returns 0 or 1.


# Choose an action, taking the exploration rate into account.
def take_epsilon_greedy_action(epsilon, average_rewards):
    # Take a random action with probability epsilon, else take the best action.
    result = np.random.uniform()
    if result < epsilon:
        return np.random.randint(0, len(average_rewards))  # Random action
    else:
        return np.argmax(average_rewards)  # Greedy action


if __name__ == "__main__":
    # Probability of success of each bandit
    bandits = [0.1, 0.3, 0.05, 0.55, 0.4]
    num_iterations = 1000
    epsilon = 0.1

    # Store info to know which one is the best action at each moment
    total_rewards = [0 for _ in range(len(bandits))]
    total_attempts = [0 for _ in range(len(bandits))]
    average_rewards = [0.0 for _ in range(len(bandits))]
    # print(total_rewards, total_attempts, average_rewards)

    for iteration in range(1, num_iterations + 1):
        action = take_epsilon_greedy_action(epsilon, average_rewards)
        reward = pull_bandit_arm(bandits, action)
        # print("action: {}, reward: {}".format(action, reward))

        # Store result
        total_rewards[action] += reward
        total_attempts[action] += 1
        average_rewards[action] = total_rewards[action] / float(total_attempts[action])

        if iteration % 100 == 0:
            print('Average reward for bandits in iteration {} is {}'.format(
                iteration, ['{:.2f}'.format(elem) for elem in average_rewards]))

    # Print results
    best_bandit = np.argmax(average_rewards)
    print('\nBest bandit is {} with an average observed reward of {:.4f}'
          .format(best_bandit, average_rewards[best_bandit]))
    print('Total observed reward in the {} episodes has been {}'
          .format(num_iterations, sum(total_rewards)))
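
As a side note, the per-arm running mean does not strictly need the separate total_rewards list: it can be maintained with the standard incremental update Q_new = Q_old + (reward - Q_old) / n. A minimal sketch, where update_average is a hypothetical helper and not part of the gist:

def update_average(average_rewards, total_attempts, action, reward):
    # Incremental mean: equivalent to total_rewards[action] / total_attempts[action],
    # but without keeping a running reward sum per arm.
    total_attempts[action] += 1
    average_rewards[action] += (reward - average_rewards[action]) / total_attempts[action]

For a rough sanity check on the output: once the greedy choice has converged to the best arm (index 3, success probability 0.55), the expected reward per pull with epsilon = 0.1 is about 0.9 * 0.55 + 0.1 * 0.28 ≈ 0.52, since a random arm is chosen 10% of the time and the mean success probability over the five arms is 0.28.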