Instantly share code, notes, and snippets.

Embed
What would you like to do?
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import math
import toolz
import numpy as np
def stop_on_plateau(info, patience=10, tol=0.001, max_iter=None):
    """Flag, per model, whether its recent scores are still moving.

    Parameters
    ----------
    info : dict
        Maps a model identifier to its list of records (most recent last).
        Each inspected record must provide a ``'score'`` entry.
    patience : int, default 10
        Number of trailing records examined for a plateau. Models with at
        most ``patience`` records are never considered plateaued.
    tol : float, default 0.001
        A model is plateaued when every one of its last ``patience`` scores
        is below (score recorded ``patience`` records ago) + ``tol``.
    max_iter : int or None, default None
        When given, models with more than ``max_iter`` records get 0.

    Returns
    -------
    dict
        Same keys (and order) as ``info``; 0 when the model exceeded
        ``max_iter`` or plateaued, 1 otherwise.
    """
    def _flag(records):
        count = len(records)
        if max_iter is not None and count > max_iter:
            return 0
        if count <= patience:
            # Not enough history yet to judge a plateau.
            return 1
        window = records[-patience:]
        baseline = window[0]['score']
        # Keep the strict ``<`` test so NaN scores behave as before.
        plateaued = all(rec['score'] < baseline + tol for rec in window)
        return 0 if plateaued else 1

    return {ident: _flag(records) for ident, records in info.items()}
class SHA:
    """Successive-halving style scheduler over a fixed pool of models.

    Each call to :meth:`fit` inspects the training history of every model
    and returns a mapping of model identifier -> 0/1.

    Parameters
    ----------
    n : int
        Number of models; the first :meth:`fit` call must receive exactly
        this many.
    r : int
        Base ``partial_fit`` budget. The first round targets ``r`` calls per
        model; later budgets are derived from ``r * eta ** steps``.
    eta : int or float, default 3
        Reduction factor. At step ``s`` at most ``n // eta ** s`` models are
        kept.
    """

    def __init__(self, n, r, eta=3):
        self.steps = 0  # round counter; 0 until the first fit() call
        self.n = n
        self.r = r
        self.eta = eta

    def fit(self, info):
        """Advance the schedule given the per-model training history.

        Parameters
        ----------
        info : dict
            Maps a model identifier to its list of records, most recent
            last. Records are read for ``'score'`` (via ``stop_on_plateau``
            and the top-k ranking) and ``'partial_fit_calls'``.

        Returns
        -------
        dict
            On the first call, 1 for every model. On later calls:

            - ``stop_on_plateau``'s flags unchanged, if every model plateaued;
            - 1 for every model, while the recorded call counts differ from
              the current targets;
            - ``{best_model: 0}``, once at most one model survives selection;
            - otherwise 1 for each surviving model.

        Raises
        ------
        AssertionError
            On the first call, if ``len(info) != n``.
        """
        n, r, eta = self.n, self.r, self.eta
        # Exact floor division instead of math.floor(n * eta ** -steps): the
        # float product can land just below an integer, e.g.
        # 49 * 7 ** -2 == 0.9999999999999999, which would floor to 0.
        n_i = int(n // eta ** self.steps)
        r_i = r * eta ** self.steps

        def score(k):
            # Most recent score of model ``k``.
            return info[k][-1]['score']

        if self.steps == 0:
            # Validate before mutating so a failed check leaves state intact.
            assert len(info) == self.n
            self.steps = 1
            self.to_reach = {k: r_i for k in info}
            return {k: 1 for k in info}

        keep_training = stop_on_plateau(info)
        n_improving = sum(keep_training.values())
        if n_improving == 0:
            # Every model plateaued: pass the all-zero flags through.
            return keep_training
        # Ratio of all models to non-plateaued ones; scales the next budget.
        iteration_increase = len(info) / n_improving
        # NOTE(review): stop_on_plateau returns an entry for *every* model,
        # so plateaued models are not filtered out here and still take part
        # in the ranking below. Confirm whether they should be excluded.

        calls = {k: records[-1]['partial_fit_calls']
                 for k, records in info.items()}
        # Only select once the recorded call counts match the targets exactly.
        # NOTE(review): after the first selection ``to_reach`` holds scaled
        # *remaining* calls while ``calls`` holds totals; confirm the caller
        # makes these comparable.
        if calls != self.to_reach:
            return {k: 1 for k in info}

        best = toolz.topk(n_i, info, key=score)
        if 1 <= len(best) < eta:
            # Fewer than eta survivors: remember the single best one.
            self._best_arm = max(best, key=score)
        if len(best) <= 1:
            if not hasattr(self, '_best_arm'):
                # n_i reached 0 before any round recorded a best arm
                # (e.g. n < eta); fall back to the top-scoring model.
                self._best_arm = max(best or info, key=score)
            return {self._best_arm: 0}

        to_reach = {k: r_i - calls[k] for k in best}
        self.to_reach = {k: int(v * iteration_increase)
                         for k, v in to_reach.items()}
        self.steps += 1
        return {k: 1 for k in to_reach}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment