Last active
April 24, 2018 17:40
-
-
Save fohria/52f70225106336f3a9b72d85dedbdf10 to your computer and use it in GitHub Desktop.
Generate reward sequences for a two-armed bandit.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Generate reward sequences for a two-armed bandit task.

Each arm's sequence is re-drawn until its average is close to the ideal
average of that arm's reward range: the good arm uses rewards 2-10
(average should be 6) and the bad arm uses rewards 1-8 (average should
be 4.5).
"""
import numpy as np | |
# Sampling settings shared by both arms.
_SEQ_LENGTH = 50   # number of rewards in each generated sequence
_PRECISION = 0.05  # max allowed distance between sample mean and ideal mean

# Reward bounds per arm. The ideal means quoted in the module docstring
# (4.5 for the bad arm, 6 for the good arm) treat both bounds as included.
_BAD_ARM_LOW, _BAD_ARM_HIGH = 1, 8
_GOOD_ARM_LOW, _GOOD_ARM_HIGH = 2, 10
def generate_sequence(low, high, length, precision=None, max_attempts=None):
    """Draw random reward sequences until one has (nearly) the ideal mean.

    Rewards are integers drawn uniformly from the inclusive range
    ``[low, high]``. Sequences are re-drawn until the sample average lies
    strictly within ``precision`` of the ideal average of that range,
    ``(low + high) / 2``.

    Args:
        low: smallest possible reward (inclusive).
        high: largest possible reward (inclusive); must be >= ``low``.
        length: number of rewards in the sequence; must be >= 1.
        precision: maximum allowed distance between the sample average and
            the ideal average; must be > 0. Defaults to the module-level
            ``_PRECISION``.
        max_attempts: optional cap on how many sequences are drawn. ``None``
            (the default) keeps drawing until a match is found.

    Returns:
        A numpy integer array of ``length`` rewards.

    Raises:
        ValueError: if ``high < low``, ``length < 1`` or ``precision <= 0``.
        RuntimeError: if ``max_attempts`` sequences were drawn without a match.
    """
    if precision is None:
        precision = _PRECISION
    if high < low:
        raise ValueError("high ({}) must be >= low ({})".format(high, low))
    if length < 1:
        raise ValueError("length must be >= 1, got {}".format(length))
    if precision <= 0:
        # The acceptance test below is strict, so precision <= 0 never matches.
        raise ValueError("precision must be > 0, got {}".format(precision))

    # Mean of the integers low..high (equal to averaging low, low+1, ..., high).
    goal_average = (low + high) / 2.0

    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        print("generating new sequence...")
        # randint's upper bound is exclusive: add 1 so `high` can be drawn.
        # Without it the expected mean is 0.5 below goal_average.
        sequence = np.random.randint(low, high + 1, length)
        average = np.average(sequence)
        if abs(average - goal_average) < precision:
            print("found a sequence! average is {}".format(average))
            return sequence
        print("average of {} is not close enough!".format(average))

    raise RuntimeError(
        "no sequence with average within {} of {} after {} attempts".format(
            precision, goal_average, max_attempts))
# Draw the bad arm first, then the good arm.
print("generating bad arm...")
bad_arm = generate_sequence(
    low=_BAD_ARM_LOW, high=_BAD_ARM_HIGH, length=_SEQ_LENGTH)

print("generating good arm...")
good_arm = generate_sequence(
    low=_GOOD_ARM_LOW, high=_GOOD_ARM_HIGH, length=_SEQ_LENGTH)

# Report the good arm first, each sequence as a comma-separated line.
print("ALL DONE! here are the sequences")
print("good arm sequence:", ", ".join(map(str, good_arm)), sep="\n")
print("bad arm sequence:", ", ".join(map(str, bad_arm)), sep="\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment