Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Created August 12, 2023 22:57
Show Gist options
  • Save unixpickle/05d35364a07c9f3534c9012a610827eb to your computer and use it in GitHub Desktop.
Save unixpickle/05d35364a07c9f3534c9012a610827eb to your computer and use it in GitHub Desktop.
Guess the number
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"import math\n",
"import random\n",
"from typing import Callable\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MAX_GUESSES = 5\n",
"MIN_NUM = 1\n",
"MAX_NUM = 99\n",
"\n",
"class Environment:\n",
" def __init__(self):\n",
" self._number = random.randrange(MIN_NUM, MAX_NUM + 1)\n",
" self._guesses = 0\n",
" self._total_guesses = 0\n",
" self._wins = 0\n",
" self._losses = 0\n",
"\n",
" @property\n",
" def total_guesses(self) -> int:\n",
" return self._total_guesses\n",
"\n",
" def guess(self, n: int) -> tuple[int, bool]:\n",
" self._guesses += 1\n",
" self._total_guesses += 1\n",
" res = 0\n",
" if n < self._number:\n",
" res = -1\n",
" elif n > self._number:\n",
" res = 1\n",
" success = not res\n",
" exhausted_guesses = self._guesses >= MAX_GUESSES\n",
" if success or exhausted_guesses:\n",
" if success:\n",
" self._wins += 1\n",
" else:\n",
" self._losses += 1\n",
" self._number = random.randrange(MIN_NUM, MAX_NUM + 1)\n",
" self._guesses = 0\n",
" return res, False\n",
" return res, True\n",
"\n",
" def forfeit(self):\n",
" self._losses += 1\n",
" self._number = random.randrange(MIN_NUM, MAX_NUM + 1)\n",
" self._guesses = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class ConfidenceInterval:\n",
" mean: float\n",
" std: float\n",
"\n",
" def __repr__(self) -> str:\n",
" return f\"{self.mean:.02f} +/- {self.std:.02f}\"\n",
"\n",
"def guesses_until_victory(policy: Callable[[Environment], None]) -> int:\n",
" env = Environment()\n",
" policy(env)\n",
" return env.total_guesses\n",
"\n",
"def avg_guesses_until_victory(\n",
" policy: Callable[[Environment], None],\n",
" trials: int = 10000,\n",
") -> ConfidenceInterval:\n",
" counts = [guesses_until_victory(policy) for _ in range(trials)]\n",
" mean = sum(counts) / len(counts)\n",
" std = math.sqrt(sum((x - mean)**2 for x in counts) / (len(counts) ** 2))\n",
" return ConfidenceInterval(mean=mean, std=std)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_search(env: Environment, min: int=MIN_NUM-1, max: int=MAX_NUM+1):\n",
" while True:\n",
" guess = (min + max) // 2\n",
" result, can_continue = env.guess(guess)\n",
" if result == 0:\n",
" # We won, so we are done.\n",
" return\n",
" if can_continue:\n",
" if result == -1:\n",
" min = guess\n",
" else:\n",
" max = guess\n",
" else:\n",
" min = MIN_NUM\n",
" max = MAX_NUM\n",
"\n",
"avg_guesses_until_victory(binary_search)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bail_binary_search(bail_bound: int, env: Environment):\n",
" while True:\n",
" result, _ = env.guess(bail_bound)\n",
" if result == 0:\n",
" # We got lucky!\n",
" return\n",
" elif result == 1:\n",
" # the real number is below the bound\n",
" break\n",
" # Bail because the real number is too high.\n",
" env.forfeit()\n",
" return binary_search(env, max=bail_bound+1)\n",
"\n",
"xs = range(1, 100)\n",
"ys = [avg_guesses_until_victory(lambda env: bail_binary_search(x, env)).mean for x in xs]\n",
"print('minimum value achieved at:', ys.index(min(ys)) + xs[0])\n",
"plt.plot(xs, ys, label='bail and restart if above x')\n",
"plt.axhline(avg_guesses_until_victory(binary_search).mean, color='r', label='binary search')\n",
"plt.xlabel('maximum first guess')\n",
"plt.ylabel('avg number of guesses')\n",
"plt.legend()\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment