import math
import toolz
import numpy as np
def stop_on_plateau(info, patience=10, tol=0.001, max_iter=None):
    """Decide, per model, whether to keep training from its score history.

    Parameters
    ----------
    info : dict
        Maps a model identifier to its list of records. Each record is a
        dict with at least a ``'score'`` entry; the last record is the
        most recent.
    patience : int, default 10
        Number of most recent records examined for a plateau. Only models
        with more than ``patience`` records can be stopped on a plateau.
    tol : float, default 0.001
        A model has plateaued when none of its last ``patience`` scores is
        at least ``tol`` above the first score of that window.
    max_iter : int or None, default None
        If given, models with more than ``max_iter`` records are stopped
        regardless of their scores.

    Returns
    -------
    dict
        Maps every identifier in ``info`` to ``0`` (stop) or ``1`` (keep
        training).
    """
    out = {}
    for ident, records in info.items():
        if max_iter is not None and len(records) > max_iter:
            out[ident] = 0
        elif len(records) > patience:
            window = records[-patience:]
            old = window[0]['score']
            # Plateau: no recent score improved on the window's first score
            # by at least ``tol``.
            plateaued = all(d['score'] < old + tol for d in window)
            out[ident] = 0 if plateaued else 1
        else:
            # Not enough history yet to judge a plateau.
            out[ident] = 1
    return out
class SHA:
    """Successive-halving schedule driven by per-model training records.

    Each call to :meth:`fit` receives the training history of the models
    still in play and returns an instruction per model. Rung ``s`` keeps the
    top ``floor(n * eta ** -s)`` models by latest score and trains them up to
    ``r * eta ** s`` ``partial_fit`` calls. That budget is scaled up when
    other models were just stopped on a plateau; see :meth:`fit`.

    Parameters
    ----------
    n : int
        Number of models expected on the first call to :meth:`fit`.
    r : int
        ``partial_fit_calls`` target for every model in the first rung.
    eta : int or float, default 3
        Reduction factor between consecutive rungs.
    """

    def __init__(self, n, r, eta=3):
        self.steps = 0  # index of the rung scheduled by the next transition
        self.n = n
        self.r = r
        self.eta = eta
        # Leader recorded once fewer than ``eta`` models survive; returned if
        # a later rung would keep no models at all.
        self._best_arm = None

    def fit(self, info):
        """Advance the schedule given the latest training records.

        Parameters
        ----------
        info : dict
            Maps a model identifier to its list of records. Each record is a
            dict with ``'score'`` and ``'partial_fit_calls'`` entries; the
            last record is the most recent.

        Returns
        -------
        dict
            Maps model identifiers to ``1`` (keep training) or ``0`` (stop).
            NOTE(review): presumably the caller drops identifiers missing
            from the result -- confirm against the caller.
        """
        n, r, eta = self.n, self.r, self.eta
        n_i = math.floor(n * eta ** -self.steps)  # models kept in this rung
        r_i = r * eta ** self.steps  # partial_fit_calls target for this rung

        if self.steps == 0:
            self.steps = 1
            assert len(info) == self.n
            # Absolute partial_fit_calls target for the first rung.
            self.to_reach = {k: r_i for k in info}
            return {k: 1 for k in info}

        keep_training = stop_on_plateau(info)
        n_active = sum(keep_training.values())
        if n_active == 0:
            return keep_training
        # Survivors' remaining budget is scaled by (all models / active ones)
        # so budget freed by plateaued models is handed to the rest.
        iteration_increase = len(info) / n_active
        # keep_training has an entry for every model; drop the stopped ones.
        info = {k: records for k, records in info.items() if keep_training[k]}

        calls = {k: records[-1]['partial_fit_calls'] for k, records in info.items()}
        # Stay in the current rung until every survivor reaches its absolute
        # target; ``<`` (not ``!=``) so an overshooting model cannot stall it.
        if any(calls[k] < self.to_reach[k] for k in calls):
            return {k: 1 for k in info}

        def score(k):
            return info[k][-1]['score']

        best = toolz.topk(n_i, info, key=score)
        if 1 <= len(best) < eta:
            self._best_arm = max(best, key=score)
        if len(best) <= 1:
            # Final rung: report the winner with a stop instruction.
            if best:
                best_arm = best[0]
            elif self._best_arm is not None:
                best_arm = self._best_arm
            else:
                # n_i can reach 0 without ever landing in [1, eta) (e.g.
                # n < eta); fall back to the best current model.
                best_arm = max(info, key=score)
            return {best_arm: 0}

        # Store an ABSOLUTE target (current calls + scaled remaining budget),
        # matching the first rung and the comparison against ``calls`` above.
        self.to_reach = {
            k: calls[k] + int((r_i - calls[k]) * iteration_increase)
            for k in best
        }
        self.steps += 1
        return {k: 1 for k in self.to_reach}
