Skip to content

Instantly share code, notes, and snippets.

@fohria
Last active April 24, 2018 17:40
Show Gist options
  • Save fohria/52f70225106336f3a9b72d85dedbdf10 to your computer and use it in GitHub Desktop.
Save fohria/52f70225106336f3a9b72d85dedbdf10 to your computer and use it in GitHub Desktop.
Generate reward sequences for a two-armed bandit
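""" Generate fixed reward sequences for a two-armed bandit.

Just run this script: it keeps drawing random sequences until each one's
average is close to the ideal average for its arm.
- Good arm: integer rewards from 2 to 10 (inclusive), so the ideal average is 6.
- Bad arm: integer rewards from 1 to 8 (inclusive), so the ideal average is 4.5.
"""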
import numpy as np
# Search settings.
_PRECISION = 0.05  # max allowed distance between a sequence's mean and the ideal mean
_SEQ_LENGTH = 50  # number of rewards generated per arm

# Reward bounds per arm; both ends are meant to be possible rewards.
_BAD_ARM_LOW, _BAD_ARM_HIGH = 1, 8  # ideal mean (1 + 8) / 2 = 4.5
_GOOD_ARM_LOW, _GOOD_ARM_HIGH = 2, 10  # ideal mean (2 + 10) / 2 = 6
def generate_sequence(low, high, length, precision=None):
    """Draw random integer rewards until their mean is close to the ideal mean.

    Rewards are drawn uniformly from ``low`` to ``high`` *inclusive* using
    NumPy's global random state, so the ideal mean is ``(low + high) / 2``.
    Whole sequences are redrawn until one has a mean strictly within
    ``precision`` of that ideal. Progress is printed for every attempt.

    Args:
        low: Smallest possible reward (integer).
        high: Largest possible reward (integer, inclusive).
        length: Number of rewards in the sequence; must be at least 1.
        precision: Maximum allowed distance between the sequence mean and
            the ideal mean. Defaults to the module-level ``_PRECISION``.

    Returns:
        A NumPy integer array of ``length`` rewards.

    Raises:
        ValueError: If ``low > high``, if ``length < 1``, or if no sequence of
            this length can reach the required precision. Without these checks
            the search loop would never terminate.
    """
    if precision is None:
        precision = _PRECISION
    if low > high:
        raise ValueError("low ({}) must not exceed high ({})".format(low, high))
    if length < 1:
        raise ValueError("length must be at least 1, got {}".format(length))

    goal_average = (low + high) / 2
    # The mean of `length` integers is always a multiple of 1/length. If even
    # the closest such value misses the target, no amount of retrying helps.
    closest_achievable = round(goal_average * length) / length
    if abs(closest_achievable - goal_average) >= precision:
        raise ValueError(
            "no sequence of length {} can have a mean within {} of {}".format(
                length, precision, goal_average))

    while True:
        print("generating new sequence...")
        # randint's upper bound is exclusive; +1 makes `high` a possible reward.
        sequence = np.random.randint(low, high + 1, length)
        average = np.average(sequence)
        if np.abs(average - goal_average) < precision:
            print("found a sequence! average is {}".format(average))
            return sequence
        print("average of {} is not close enough!".format(average))
# Build both arms: bad arm first, then good arm. Both draws use NumPy's
# global random state, so this order determines which numbers each arm gets.
print("generating bad arm...")
bad_arm = generate_sequence(_BAD_ARM_LOW, _BAD_ARM_HIGH, _SEQ_LENGTH)
print("generating good arm...")
good_arm = generate_sequence(_GOOD_ARM_LOW, _GOOD_ARM_HIGH, _SEQ_LENGTH)

# Report the results as comma-separated lists, good arm first.
print("ALL DONE! here are the sequences")
for _heading, _rewards in (("good arm sequence:", good_arm),
                           ("bad arm sequence:", bad_arm)):
    print(_heading)
    print(", ".join(map(str, _rewards)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment