Skip to content

Instantly share code, notes, and snippets.

@edlanglois
Created February 21, 2017 02:18
Show Gist options
  • Save edlanglois/c6fb9b3c76f2e484f959fd5e63c23044 to your computer and use it in GitHub Desktop.
Save edlanglois/c6fb9b3c76f2e484f959fd5e63c23044 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# MIT License
#
# Copyright 2017 Eric Langlois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
def value_table(n, p=0.5):
    """Build the value and stop-decision tables for the n-throw stopping game.

    The tables are filled by backward induction over the number of throws
    remaining. With ``k`` throws left at score ``s``, continuing is worth
    ``(1 - p) * V[k - 1, s - 1] + p * V[k - 1, s + 1]``; at the lowest and
    highest tracked scores the out-of-range neighbour is replaced by the
    same score. The stored value is the larger of that continuation value
    and ``s`` itself. With no throws left the value is ``max(s, 0)``.

    Args:
        n: Number of throws. Scores from ``-n`` to ``n`` are tracked.
            Must be non-negative.
        p: Probability that a throw raises the score by one (it lowers the
            score by one otherwise).

    Returns:
        A tuple ``(values, stop, stopping_scores)``:

        * ``values`` -- float array of shape ``(n + 1, 2 * n + 1)`` indexed
          by ``[throws_left, score + n]``.
        * ``stop`` -- boolean array of the same shape; True where the
          continuation value is strictly less than the current score.
          Row 0 is never written and stays all False.
        * ``stopping_scores`` -- for each number of throws left, the lowest
          score at which ``stop`` is True, or ``n + 1`` (one past the
          highest score) when ``stop`` is False for the whole row.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got {}".format(n))

    scores = np.arange(-n, n + 1)
    # shape: [moves_left, current_score]
    shape = (n + 1, len(scores))
    values = np.zeros(shape)
    # The np.bool alias was removed in NumPy 1.24; the builtin bool gives
    # the same dtype on every NumPy version.
    stop = np.zeros(shape, dtype=bool)

    values[0, :] = np.maximum(scores, 0)
    for k in range(1, n + 1):
        prev = values[k - 1]
        # Continuation value for interior scores.
        values[k, 1:-1] = (1 - p) * prev[:-2] + p * prev[2:]
        # Edge scores: the neighbour that would fall outside the table is
        # replaced by the edge score itself.
        values[k, 0] = (1 - p) * prev[0] + p * prev[1]
        values[k, -1] = (1 - p) * prev[-2] + p * prev[-1]

        # Record the stop decision before overwriting with the optimum.
        stop[k, :] = values[k, :] < scores
        values[k, :] = np.maximum(values[k, :], scores)

    # For each k, when you should stop
    first_score_to_stop = np.argmax(stop, axis=1)
    stopping_scores = scores[first_score_to_stop]
    # argmax returns 0 for an all-False row, which is indistinguishable from
    # "stop at the lowest score"; test the row directly instead and report
    # one past the max score when there is no stopping point.
    stopping_scores[~stop.any(axis=1)] = scores[-1] + 1
    return values, stop, stopping_scores
def print_value_table(values):
    """Print a 2-D table of numbers, one row per line.

    Every entry is rendered with three significant digits, right-aligned in
    an eight-character field, and entries on a row are separated by a single
    space. The whole table is emitted with one ``print`` call, so an empty
    table still produces a single blank line.
    """
    render_cell = '{:>8.3g}'.format
    lines = []
    for row in values:
        lines.append(' '.join(map(render_cell, row)))
    print('\n'.join(lines))
def plot_range(n, pmin, pmax, num_steps):
    """Show a heat map of stopping scores over a range of probabilities.

    ``value_table(n, p)`` is evaluated at ``num_steps`` evenly spaced values
    of ``p`` between ``pmin`` and ``pmax`` (inclusive). The third element of
    each result (the stopping score per number of throws remaining) becomes
    one image row; the "0 throws remaining" column is dropped. The figure is
    displayed with ``plt.show()``.

    Args:
        n: Number of throws; must be at least 1 so there is something to
            plot once the zero-throws column is removed.
        pmin: Lowest success probability to evaluate.
        pmax: Highest success probability to evaluate.
        num_steps: Number of probability samples; must be at least 1.

    Raises:
        ValueError: If ``n`` or ``num_steps`` is less than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got {}".format(n))
    if num_steps < 1:
        raise ValueError(
            "num_steps must be at least 1, got {}".format(num_steps))

    ps = np.linspace(pmin, pmax, num_steps)
    stopping_scores = np.stack([value_table(n, p)[2] for p in ps])
    # Cut off the "0 throws remaining" part
    stopping_scores = stopping_scores[:, 1:]

    # Plain Python ints for the colormap size and the tick range.
    vmin = int(np.min(stopping_scores))
    vmax = int(np.max(stopping_scores))

    ax = plt.gca()
    # One discrete colour per distinct integer stopping score.
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # accepts the same (name, lut) arguments and is still supported.
    cmap = plt.get_cmap('rainbow', vmax - vmin + 1)
    image = ax.matshow(stopping_scores, extent=[1, n + 1, pmin, pmax],
                       origin='lower', aspect='auto', cmap=cmap)
    ax.set_xlabel('Number of Throws Remaining')
    ax.set_ylabel('Success Probability')
    ax.set_title('Optimal Stopping Scores')
    ax.set_xticks(range(1, n + 1))

    cbar = plt.colorbar(image)
    cbar.set_label('Stopping Score')
    cbar.set_ticks(range(vmin, vmax + 1))

    plt.show()
def main():
    """Command-line entry point.

    Without ``--plot-range`` the value table for ``-n`` throws at success
    probability ``-p`` is printed, followed by the array of stopping scores.
    With ``--plot-range MIN MAX`` the stopping scores are plotted over that
    probability range instead, sampled at ``--plot-steps`` points.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-n', type=int, default=10, help="Number of throws.")
    cli.add_argument(
        '-p', type=float, default=0.5, help="Success probability.")
    cli.add_argument(
        '--plot-range', type=float, nargs=2, metavar=('MIN', 'MAX'),
        help="Plot stopping scores for a range of success probabilities.")
    cli.add_argument(
        '--plot-steps', type=int, default=500,
        help='Number of different p values to plot.')
    options = cli.parse_args()

    # Text mode: no probability range requested.
    if options.plot_range is None:
        table, _, thresholds = value_table(n=options.n, p=options.p)
        print_value_table(table)
        print(thresholds)
        return

    # Plot mode: argparse guarantees exactly two values (nargs=2).
    p_low, p_high = options.plot_range
    plot_range(options.n, p_low, p_high, options.plot_steps)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment