Last active
April 28, 2020 15:54
-
-
Save AlexandreAbraham/400b6f3b4c4b360206bc8d5abd4cef05 to your computer and use it in GitHub Desktop.
Hyperband experiment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import ceil | |
from functools import partial | |
from matplotlib import pyplot as plt | |
from cycler import cycler | |
import numpy as np | |
from decimal import Context | |
# Experiment-wide settings.
eta = 3  # Hyperband reduction factor used by the experiment below
ctx = Context(prec=20)  # decimal context (20 significant digits) used by logeta
plt.rcParams['figure.constrained_layout.use'] = True  # let matplotlib fit the layout
# Using math.log leads to imprecision sometimes:
# math.log(243, 3) == 4.999999999999999, which int() truncates to 4.
# Dividing natural logarithms computed with decimal is more precise.
def logeta(value, base=None):
    """Return log(value) in the given base, computed with decimal precision.

    Args:
        value: Positive int (or Decimal) whose logarithm is wanted.
        base: Logarithm base (int or Decimal). Defaults to the module-level
            ``eta``, so existing one-argument calls behave exactly as before.

    Returns:
        ``ln(value) / ln(base)`` evaluated in the module-level decimal context
        ``ctx`` and converted to float.
    """
    if base is None:
        # Read the global at call time, exactly like the original did.
        base = eta
    return float(ctx.divide(ctx.ln(value), ctx.ln(base)))
def hyperband(max_iter, eta, paper_version=False):
    """Return the Successive Halving schedule of every Hyperband bracket.

    Args:
        max_iter: Maximum number of iterations a single configuration may
            receive; a positive integer.
        eta: Reduction factor between consecutive rungs; must be > 1.
        paper_version: When False (the "Blog HB" variant), the initial number
            of configurations truncates ``B / max_iter / (s + 1)`` before
            multiplying by ``eta ** s``. When True, it is
            ``ceil(B / max_iter / (s + 1) * eta ** s)``.

    Returns:
        One list per bracket, ordered s = s_max, ..., 0. Each list holds
        ``(n_i, r_i)`` integer pairs: run ``n_i`` configurations for ``r_i``
        iterations each.

    Raises:
        ValueError: if ``max_iter < 1`` or ``eta <= 1``.
    """
    if max_iter < 1 or eta <= 1:
        raise ValueError('max_iter must be >= 1 and eta must be > 1')
    # Number of unique executions of Successive Halving (minus one), i.e.
    # floor(log_eta(max_iter)). Exact integer arithmetic based on the *eta
    # argument*; the module-level eta used to be read here instead.
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    # Total number of iterations (without reuse) per execution of
    # Successive Halving (n, r).
    B = (s_max + 1) * max_iter

    hb_lines = []
    # Begin Finite Horizon Hyperband outer loop. Repeat indefinitely.
    for s in reversed(range(s_max + 1)):
        # Initial number of configurations. Integer floor/ceil divisions
        # replace float arithmetic, so exact values are never truncated to
        # the integer below by rounding error.
        if paper_version:
            n = int(-(-(B * eta ** s) // (max_iter * (s + 1))))  # ceil division
        else:
            n = ceil(B // max_iter // (s + 1) * eta ** s)
        lines = []
        # Finite Horizon Successive Halving with (n, r): run each of the n_i
        # configs for r_i iterations and keep the best n_i / eta.
        for i in range(s + 1):
            n_i = n // eta ** i                   # floor(n * eta**-i)
            r_i = max_iter * eta ** i // eta ** s  # floor(max_iter * eta**(i - s))
            lines.append((int(n_i), int(r_i)))
        hb_lines.append(lines)
    return hb_lines
def our_hyperband(max_iter, eta, n_follows_eta=True, transfer_unused_budget=False):
    """Build Hyperband brackets by greedily spending each bracket's budget.

    Every bracket s (s = 0 .. s_max) gets B = (s_max + 1) * max_iter
    iterations. Rungs are filled from the largest resource (max_iter) down to
    the smallest. Adding one configuration at rung i costs
    ``sum(r * eta**j for j, r in enumerate(rungs[i:]))``: its own resource
    plus eta**j times the resource of the rung j steps below it.

    Args:
        max_iter: Maximum number of iterations per configuration; >= 1.
        eta: Reduction factor between rungs; must be > 1.
        n_follows_eta: When True, rungs after the first only gain extra
            configurations while the new count ``n_i + 1`` is not a multiple
            of ``eta``.
        transfer_unused_budget: When True, budget left unspent by a bracket
            is added to the next bracket's budget.

    Returns:
        One list per bracket, ordered s = 0, ..., s_max. Each list holds
        ``(n_i, r_i)`` pairs ordered from the smallest ``r_i`` to the largest.

    Raises:
        ValueError: if ``max_iter < 1`` or ``eta <= 1``.
    """
    if max_iter < 1 or eta <= 1:
        raise ValueError('max_iter must be >= 1 and eta must be > 1')
    # floor(log_eta(max_iter)) from the *eta argument*, in exact integer
    # arithmetic (the module-level eta used to be read here instead).
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    B = (s_max + 1) * max_iter
    hb_lines = []
    unused_budget = 0
    for s in range(s_max + 1):
        n_i = 0
        # Resource per rung, largest first: max_iter, max_iter // eta, ...
        r_i = [max_iter // eta ** i for i in range(s + 1)]
        available_budget = B + unused_budget
        lines = []
        for i in range(s + 1):
            # Cost of one more configuration at rung i (decreasing in i).
            cost = sum(c * eta ** j for j, c in enumerate(r_i[i:]))
            while (available_budget >= cost  # There is enough budget
                   and (i == 0  # No constraint related to eta on first rung
                        or not n_follows_eta
                        or (n_i + 1) % eta != 0)):
                n_i += 1
                available_budget -= cost
            lines.append((n_i, r_i[i]))
            # Each configuration at this rung implies eta at the next lower one.
            n_i *= eta
        if transfer_unused_budget:
            unused_budget = available_budget
        # Report rungs from the smallest resource to the largest.
        hb_lines.append(list(reversed(lines)))
    return hb_lines
def compute_cost(hb_lines):
    """Return the total number of iterations a Hyperband schedule consumes.

    ``hb_lines`` is a sequence of brackets, each a sequence of ``(n, r)``
    pairs (n configurations run for r iterations each). The result is the
    sum of ``n * r`` over all pairs, accumulated bracket by bracket; an
    empty schedule costs 0.
    """
    return sum(sum(n * r for n, r in bracket) for bracket in hb_lines)
def estimate_ideal_budget(max_iter, eta):
    """Return the full Hyperband budget ``(s_max + 1) ** 2 * max_iter``.

    This is (s_max + 1) brackets, each with B = (s_max + 1) * max_iter
    iterations, where s_max = floor(log_eta(max_iter)).

    Raises:
        ValueError: if ``max_iter < 1`` or ``eta <= 1``.
    """
    if max_iter < 1 or eta <= 1:
        raise ValueError('max_iter must be >= 1 and eta must be > 1')
    # Exact integer floor(log_eta(max_iter)) using the *eta argument*; the
    # module-level eta used to be read instead, ignoring this parameter.
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    return (s_max + 1) ** 2 * max_iter
# (label, schedule function) pairs that are plotted; variants that need
# keyword options are wrapped in functools.partial.
methods = [
    (label, partial(schedule, **options) if options else schedule)
    for label, schedule, options in (
        ('Blog HB', hyperband, {}),
        ('Paper HB', hyperband, {'paper_version': True}),
        ('Our HB v1', our_hyperband, {}),
        ('Our HB v2', our_hyperband, {'n_follows_eta': False}),
        ('Our HB v3', our_hyperband,
         {'n_follows_eta': False, 'transfer_unused_budget': True}),
    )
]
# Nice colors from seaborn
muted = ["#4878D0", "#EE854A", "#6ACC64", "#D65F5F", "#956CB4",
         "#797979", "#8C613C", "#DC7EC0", "#D5BB67", "#82C6E2"]
plt.gca().set_prop_cycle(cycler(color=muted))

# max_iter values 81..277; with eta = 3 the ideal budget ranges from about
# 2000 (25 * 81) to 10000 (36 * 277).
max_iters = np.arange(81, 278)
# .item() converts each numpy integer into a plain Python int before it is
# passed to the budget functions.
ideal_costs = [estimate_ideal_budget(max_iter.item(), eta) for max_iter in max_iters]
best = sum(ideal_costs)
# Reference diagonal: spending exactly the available budget. alpha must lie
# in [0, 1]; the former value 50 is rejected by matplotlib >= 3.3.
plt.plot(ideal_costs, ideal_costs, label='Budget', color='gray', alpha=0.5)
for m_name, m_fun in methods:
    costs = [compute_cost(m_fun(max_iter.item(), eta))
             for max_iter in max_iters]
    plt.scatter(ideal_costs, costs, marker='.', label=m_name, s=50)
    # Aggregate ratio of budget spent to ideal budget over all max_iter values.
    print('{}:\t{}'.format(m_name, sum(costs) / best))

# NOTE(review): these annotation positions and labels are hard-coded. The
# labels mention R=1210 and R=1458, but max_iters only spans 81-277; if R
# means max_iter here, the annotations look stale. Confirm against the data.
plt.annotate(
    '5 brackets\nR=1210', xy=(6156, 5546), xytext=(7000, 4500), color='r',
    arrowprops=dict(arrowstyle='->', color='r', connectionstyle="arc3,rad=0.5"))
plt.annotate(
    '6 brackets\nR=1458', xy=(8750, 8328), xytext=(8500, 6000), color='r',
    arrowprops=dict(arrowstyle='->', color='r', connectionstyle="arc3,rad=-0.5"))
plt.legend()
plt.grid()
plt.gca().set_axisbelow(True)
plt.xlabel('Budget available')
plt.ylabel('Budget spent')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment