Skip to content

Instantly share code, notes, and snippets.

@stsievert
Last active August 1, 2018 18:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stsievert/c675b3a237a60efbd01dcb112e29115b to your computer and use it in GitHub Desktop.
Save stsievert/c675b3a237a60efbd01dcb112e29115b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import math
import toolz
import numpy as np
def stop_on_plateau(info, patience=10, tol=0.001, max_iter=None):
    """Flag which models should keep training.

    Parameters
    ----------
    info : dict
        Maps a model id to its list of record dicts, oldest first. Each
        record must contain a ``'score'`` key; larger scores count as
        improvement.
    patience : int, default 10
        Size of the look-back window. Once a model has more than
        ``patience`` records, it is stopped if none of the records after
        ``records[-patience]`` beats that record's score by more than
        ``tol``. ``None`` or a value <= 0 disables plateau detection.
    tol : float, default 0.001
        Improvement over the window's first score that must be exceeded to
        count as progress.
    max_iter : int, optional
        Stop any model that has more than ``max_iter`` records.

    Returns
    -------
    dict
        Maps every id in ``info`` to 0 (stop) or 1 (keep training).
    """
    out = {}
    for ident, records in info.items():
        if max_iter is not None and len(records) > max_iter:
            out[ident] = 0
        elif patience is not None and 0 < patience < len(records):
            # Explicit 0 < patience guard: records[-0] would silently be
            # the *first* record, not a window of size zero.
            window = records[-patience:]
            old = window[0]['score']
            # Skip the reference record and use <= so that tol=0 means
            # "no strict improvement", and a flat history still plateaus
            # even when old + tol rounds back to old for large scores.
            plateaued = all(d['score'] <= old + tol for d in window[1:])
            out[ident] = 0 if plateaued else 1
        else:
            out[ident] = 1
    return out
class SHA:
    """Successive-halving style scheduler driven by per-model training records.

    ``fit`` is called repeatedly with ``info``, a mapping from model id to a
    list of record dicts. The code reads the latest record's ``'score'`` and
    ``'partial_fit_calls'`` keys.

    Rung ``s`` keeps the top ``floor(n * eta**-s)`` models by latest score.
    It sets per-model call targets derived from ``r * eta**s``.

    Parameters
    ----------
    n : int
        Number of models at the first rung. The first ``fit`` call asserts
        ``len(info) == n``.
    r : number
        Call target for every model at the first rung. It is compared with
        each model's ``'partial_fit_calls'``.
    eta : number, default 3
        Reduction factor between rungs.
    """

    def __init__(self, n, r, eta=3):
        self.steps = 0
        self.n = n
        self.r = r
        self.eta = eta
        # Winning model id, recorded once a rung keeps fewer than ``eta``
        # models; None until then.
        self._best_arm = None

    def fit(self, info):
        """Decide which models receive further training.

        Parameters
        ----------
        info : dict
            Maps model id to a non-empty list of record dicts. The records
            are also passed to ``stop_on_plateau``, which reads ``'score'``.

        Returns
        -------
        dict
            Maps model id to 1 (train further) or 0 (stop), following the
            0/1 convention of ``stop_on_plateau``. Once a single winner is
            chosen, the result is ``{best_id: 0}``.
        """
        n, r, eta = self.n, self.r, self.eta
        # Rung size and call budget for the current step. They are computed
        # before ``self.steps`` is modified below.
        n_i = math.floor(n * eta ** -self.steps)
        r_i = r * eta ** self.steps

        if self.steps == 0:
            # First call: every model trains until it reaches r calls.
            self.steps = 1
            assert len(info) == self.n
            self.to_reach = {k: r_i for k in info}
            return {k: 1 for k in info}

        keep_training = stop_on_plateau(info)
        if sum(keep_training.values()) == 0:
            # Every model has plateaued (or info is empty).
            return keep_training
        # Ratio of all models to those still improving.
        iteration_increase = len(info) / sum(keep_training.values())
        # NOTE(review): stop_on_plateau returns an entry for every id, so
        # this keeps every model. Filtering on ``keep_training[k]`` may have
        # been intended; confirm.
        info = {k: info[k] for k in keep_training}

        calls = {k: record[-1]['partial_fit_calls']
                 for k, record in info.items()}
        if calls != self.to_reach:
            # The current rung's call targets are not all met yet.
            # NOTE(review): this is whole-dict equality. Any id in ``info``
            # that is missing from ``self.to_reach`` (e.g. a model cut at an
            # earlier rung, if callers keep passing it) prevents a match.
            return {k: 1 for k in info}

        best = toolz.topk(n_i, info, key=lambda k: info[k][-1]['score'])
        if 1 <= len(best) < eta:
            self._best_arm = max(best, key=lambda k: info[k][-1]['score'])
        if len(best) in {0, 1}:
            if self._best_arm is None:
                # No rung has recorded a winner yet. For example, n < eta
                # makes the first cut keep zero models. Fall back to the
                # model with the highest latest score instead of raising
                # AttributeError.
                self._best_arm = max(info, key=lambda k: info[k][-1]['score'])
            return {self._best_arm: 0}

        # NOTE(review): these targets are *remaining* calls (r_i minus calls
        # made so far), but the check above compares them with *cumulative*
        # ``partial_fit_calls``. An absolute target of r_i may be intended;
        # confirm against the caller.
        to_reach = {k: r_i - info[k][-1]['partial_fit_calls'] for k in best}
        self.to_reach = {k: int(v * iteration_increase)
                         for k, v in to_reach.items()}
        self.steps += 1
        return {k: 1 for k in to_reach}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment