Skip to content

Instantly share code, notes, and snippets.

@comckay
Last active September 28, 2021 13:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save comckay/74a98a30911e03a47ef1340e15e3bc1d to your computer and use it in GitHub Desktop.
Save comckay/74a98a30911e03a47ef1340e15e3bc1d to your computer and use it in GitHub Desktop.
from typing import List
import numpy as np
class UCB1:
    """UCB1 multi-armed bandit that chooses among a fixed set of named models.

    Every model is tried once, in list order, before the UCB1 rule is used.
    After that, ``select_model`` returns the model maximising

        mean_reward + sqrt(2 * ln(total_tries) / model_tries)

    with ties broken in favour of the earliest model in the list.
    ``select_model`` records one try for the model it returns; call
    ``reward_model`` afterwards to credit that try.
    """

    def __init__(self, models: List[str]):
        """Create a bandit over ``models``.

        Args:
            models: Distinct model names. The list is copied, so later
                mutation by the caller cannot desynchronise the counters.

        Raises:
            ValueError: if ``models`` contains duplicate names. With
                duplicates, name-based lookups resolve to the first copy
                only, and the later copies would never be counted as tried.
        """
        self.models = list(models)
        n_models = len(self.models)
        # Name -> position lookup (O(1) instead of list.index per call).
        self._model_index = {name: i for i, name in enumerate(self.models)}
        if len(self._model_index) != n_models:
            raise ValueError("model names must be unique")
        # Per-model accumulated reward and number of times selected.
        self.model_successes = np.zeros(n_models)
        self.model_tries = np.zeros(n_models)

    def _index_of(self, model: str) -> int:
        """Return the position of ``model``; raise ValueError if unknown."""
        try:
            return self._model_index[model]
        except KeyError:
            raise ValueError(f"model {model} not recognized") from None

    def _increment_model_tries(self, model: str) -> None:
        """Record one more try of ``model``."""
        self.model_tries[self._index_of(model)] += 1

    def _max_ucb_index(self) -> int:
        """Return the position of the model with the highest UCB1 score.

        Only valid once every model has at least one try; otherwise the
        per-model division below would divide by zero.
        """
        exploration = 2 * np.log(np.sum(self.model_tries))
        per_model_means = self.model_successes / self.model_tries
        ucb1_estimates = per_model_means + np.sqrt(exploration / self.model_tries)
        return int(np.nanargmax(ucb1_estimates))

    def _get_model_with_max_ucb(self) -> str:
        """Return the name of the model with the highest UCB1 score."""
        return self.models[self._max_ucb_index()]

    def select_model(self) -> str:
        """Choose the next model to try and record the try.

        Untried models are returned first, in list order. Once every model
        has been tried, the model with the highest UCB1 score is returned.

        Raises:
            ValueError: if the bandit was created with no models.
        """
        if not self.models:
            raise ValueError("no models to select from")
        untested = np.flatnonzero(self.model_tries == 0)
        index = int(untested[0]) if untested.size else self._max_ucb_index()
        # Update by position so the counter of the exact arm returned is
        # the one that changes.
        self.model_tries[index] += 1
        return self.models[index]

    def reward_model(self, model: str, reward: float = 1.0) -> None:
        """Credit ``model`` with ``reward`` (1.0 by default, i.e. a success).

        UCB1's exploration bonus is derived for rewards in [0, 1]; larger
        values are accepted but skew the balance toward exploitation.

        Raises:
            ValueError: if ``model`` is not one of the bandit's models.
        """
        self.model_successes[self._index_of(model)] += reward
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment