Skip to content

Instantly share code, notes, and snippets.

@AlexandreAbraham
Last active April 28, 2020 15:54
Show Gist options
  • Save AlexandreAbraham/400b6f3b4c4b360206bc8d5abd4cef05 to your computer and use it in GitHub Desktop.
Save AlexandreAbraham/400b6f3b4c4b360206bc8d5abd4cef05 to your computer and use it in GitHub Desktop.
Hyperband experiment
from math import ceil
from functools import partial
from matplotlib import pyplot as plt
from cycler import cycler
import numpy as np
from decimal import Context
# Base of the logarithm / downsampling rate used by the whole experiment.
eta = 3
# Decimal context with 20 significant digits, used by logeta() for accurate logs.
ctx = Context(prec=20)
# Let matplotlib's constrained layout arrange the figure elements.
plt.rcParams['figure.constrained_layout.use'] = True
# Using math.log leads to imprecision sometimes: math.log(243, 3) returns
# 4.999999999999999, so int() of it gives 4 instead of 5.
# Taking natural logarithms in the 20-digit decimal context is more precise.
def logeta(value, base=None):
    """Return log_base(value) as a float, computed with decimal precision.

    Args:
        value: number whose logarithm is taken; an int or a Decimal (decimal
            context methods do not accept floats).
        base: base of the logarithm; defaults to the module-level ``eta``.

    Returns:
        The decimal quotient ln(value) / ln(base), rounded to the nearest float.
        Because of that final rounding, a value within float precision of an
        integer is returned as that integer.
    """
    if base is None:
        base = eta
    return float(ctx.divide(ctx.ln(value), ctx.ln(base)))
def hyperband(max_iter, eta, paper_version=False):
    """Return the (n_i, r_i) schedule of every bracket of finite-horizon Hyperband.

    Brackets are produced from the most exploratory one (s = s_max) down to
    s = 0. Each bracket is one Successive Halving run, listed as
    (number of configurations, iterations per configuration) pairs that are
    truncated to ints.

    Args:
        max_iter: maximum number of iterations per configuration (R), >= 1.
        eta: downsampling rate; must be > 1.
        paper_version: if False, the initial number of configurations uses the
            floored formula int(B / R / (s + 1)) * eta**s (the "blog" variant);
            if True, it uses ceil(B / R / (s + 1) * eta**s) (the "paper" variant).

    Returns:
        A list of s_max + 1 brackets, each a list of (n_i, r_i) int tuples.

    Raises:
        ValueError: if max_iter < 1 or eta <= 1.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be >= 1, got {}'.format(max_iter))
    if eta <= 1:
        raise ValueError('eta must be > 1, got {}'.format(eta))
    # s_max = floor(log_eta(max_iter)): the number of unique executions of
    # Successive Halving, minus one. It is computed exactly with integer powers
    # of the eta *argument* (logeta() always divides by the module-level eta).
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    # Total number of iterations (without reuse) per execution of Successive Halving.
    B = (s_max + 1) * max_iter
    hb_lines = []
    # Finite-horizon Hyperband outer loop, from the most exploratory bracket down.
    for s in reversed(range(s_max + 1)):
        # Initial number of configurations.
        if paper_version:
            n = int(ceil(B / max_iter / (s + 1) * eta ** s))
        else:
            n = int(ceil(int(B / max_iter / (s + 1)) * eta ** s))
        # Initial number of iterations to run each configuration for.
        r = max_iter * eta ** (-s)
        # Finite-horizon Successive Halving with (n, r): run each of the n_i
        # configurations for r_i iterations and keep the best n_i / eta.
        lines = []
        for i in range(s + 1):
            n_i = n * eta ** (-i)
            r_i = r * eta ** i
            lines.append((int(n_i), int(r_i)))
        hb_lines.append(lines)
    return hb_lines
def our_hyperband(max_iter, eta, n_follows_eta=True, transfer_unused_budget=False):
    """Return a Hyperband-like schedule whose brackets greedily fill their budget.

    There are s_max + 1 brackets, with s = 0 .. s_max and s_max computed as
    floor(log_eta(max_iter)). Bracket s has s + 1 rungs. Rung i runs its
    configurations for max_iter // eta**i iterations. Starting from the top
    rung (max_iter iterations), configurations are added while the bracket
    budget B = (s_max + 1) * max_iter allows. The count reached at a rung,
    multiplied by eta, is the starting count of the next, cheaper rung.

    Args:
        max_iter: maximum number of iterations per configuration, >= 1.
        eta: downsampling rate; must be > 1.
        n_follows_eta: if True, below the top rung, configurations are only
            added while n_i + 1 is not a multiple of eta.
        transfer_unused_budget: if True, the budget left over by one bracket
            is added to the next bracket's budget.

    Returns:
        A list of brackets. Each bracket is a list of (n_i, r_i) tuples ordered
        from the cheapest rung to the top rung.

    Raises:
        ValueError: if max_iter < 1 or eta <= 1.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be >= 1, got {}'.format(max_iter))
    if eta <= 1:
        raise ValueError('eta must be > 1, got {}'.format(eta))
    # Exact floor(log_eta(max_iter)) using the eta argument (logeta() always
    # divides by the module-level eta).
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    B = (s_max + 1) * max_iter
    hb_lines = []
    unused_budget = 0
    for s in range(s_max + 1):
        # Decreasing iterations per rung: max_iter, max_iter // eta, ...
        r_i = [max_iter // eta ** i for i in range(s + 1)]
        n_i = 0
        available_budget = B + unused_budget
        lines = []
        for i in range(s + 1):
            # Cost of one more configuration at rung i: r_i[i] itself plus
            # eta**j times the iterations of each cheaper rung i + j.
            cost = sum(c * eta ** j for j, c in enumerate(r_i[i:]))
            while (available_budget >= cost  # There is enough budget
                   and (i == 0  # No constraint related to eta on first iter
                        or not n_follows_eta
                        or (n_i + 1) % eta != 0)):
                n_i += 1
                available_budget -= cost
            lines.append((n_i, r_i[i]))
            # The next (cheaper) rung starts with eta times as many configurations.
            n_i *= eta
        if transfer_unused_budget:
            unused_budget = available_budget
        hb_lines.append(list(reversed(lines)))
    return hb_lines
def compute_cost(hb_lines):
    """Return the total number of iterations spent by a schedule.

    ``hb_lines`` is a list of brackets, each a list of (n, r) pairs; every
    pair contributes n * r, and an empty schedule costs 0.
    """
    return sum(sum(n * r for n, r in bracket) for bracket in hb_lines)
def estimate_ideal_budget(max_iter, eta):
    """Return (s_max + 1) ** 2 * max_iter, with s_max = floor(log_eta(max_iter)).

    This is s_max + 1 brackets of (s_max + 1) * max_iter iterations each.

    Raises:
        ValueError: if max_iter < 1 or eta <= 1.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be >= 1, got {}'.format(max_iter))
    if eta <= 1:
        raise ValueError('eta must be > 1, got {}'.format(eta))
    # Exact floor(log_eta(max_iter)) using the eta argument (logeta() always
    # divides by the module-level eta).
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    return (s_max + 1) ** 2 * max_iter
# Schedulers to compare: Hyperband as in the blog code and as in the paper,
# and three configurations of our budget-filling variant.
methods = [
    ('Blog HB', hyperband),
    ('Paper HB', partial(hyperband, paper_version=True)),
    ('Our HB v1', our_hyperband),
    ('Our HB v2', partial(our_hyperband, n_follows_eta=False)),
    ('Our HB v3', partial(our_hyperband, n_follows_eta=False,
                          transfer_unused_budget=True)),
]

# Nice colors from seaborn
muted = ["#4878D0", "#EE854A", "#6ACC64", "#D65F5F", "#956CB4",
         "#797979", "#8C613C", "#DC7EC0", "#D5BB67", "#82C6E2"]
plt.gca().set_prop_cycle(cycler(color=muted))

max_iters = np.arange(81, 278)  # Ideal cost from 2000 to 10000
# .item() converts each numpy integer to a Python int before calling the schedulers.
ideal_costs = [estimate_ideal_budget(max_iter.item(), eta) for max_iter in max_iters]
best = sum(ideal_costs)
# Reference line y = x. alpha must lie in [0, 1]; the former alpha=50 is
# rejected with a ValueError by recent matplotlib versions.
plt.plot(ideal_costs, ideal_costs, label='Budget', color='gray', alpha=0.5)
for m_name, m_fun in methods:
    costs = [compute_cost(m_fun(max_iter.item(), eta)) for max_iter in max_iters]
    plt.scatter(ideal_costs, costs, marker='.', label=m_name, s=50)
    # Total budget spent relative to the total ideal budget over the sweep.
    print('{}:\t{}'.format(m_name, sum(costs) / best))

# Hand-placed annotations; the coordinates appear to be tuned by hand for this figure.
plt.annotate(
    '5 brackets\nR=1210', xy=(6156, 5546), xytext=(7000, 4500), color='r',
    arrowprops=dict(arrowstyle='->', color='r', connectionstyle="arc3,rad=0.5"))
plt.annotate(
    '6 brackets\nR=1458', xy=(8750, 8328), xytext=(8500, 6000), color='r',
    arrowprops=dict(arrowstyle='->', color='r', connectionstyle="arc3,rad=-0.5"))
plt.legend()
plt.grid()
plt.gca().set_axisbelow(True)
plt.xlabel('Budget available')
plt.ylabel('Budget spent')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment