Created
August 19, 2021 09:14
-
-
Save josalhor/8ca386f0902f1523e0ee4cd17e337b37 to your computer and use it in GitHub Desktop.
Ax Pareto Frontier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Frontier:
    """Incrementally maintained Pareto frontier (reference implementation).

    ``objectives`` maps each objective name to an object exposing a boolean
    ``minimize`` attribute (presumably Ax ``ObjectiveProperties`` -- confirm
    against the caller).  A point is anything indexable by objective name,
    e.g. a row ``pandas.Series`` produced by ``DataFrame.iterrows``.

    Invariant: no member of ``self.frontier`` weakly dominates another one.
    """

    def __init__(self, objectives):
        self.frontier = []
        self.objectives = objectives

    def load_df(self, df):
        """Feed every row of *df* whose ``trial_status`` is 'COMPLETED'.

        Rows with any other status are skipped.  Returns ``self.frontier``
        (previously an unrelated, always-empty list was returned).
        """
        for _, point in df.iterrows():
            if point['trial_status'] != 'COMPLETED':
                continue
            self.recalculate_frontier(point)
        return self.frontier

    def pareto_dominates(self, a, b):
        """Return True if *a* is at least as good as *b* on every objective.

        This is *weak* dominance: two points with identical objective values
        dominate each other.
        """
        for ob_name, ob_prop in self.objectives.items():
            if ob_prop.minimize:
                if a[ob_name] > b[ob_name]:
                    return False
            elif a[ob_name] < b[ob_name]:
                return False
        return True

    def recalculate_frontier(self, point):
        """Insert *point* unless a frontier member dominates it.

        Members that *point* weakly dominates are removed first, so a point
        with values equal to an existing member replaces that member.
        """
        # Walk backwards so deleting index i never shifts indices still to
        # be visited.
        for i in reversed(range(len(self.frontier))):
            fp = self.frontier[i]
            if self.pareto_dominates(point, fp):
                del self.frontier[i]
            elif self.pareto_dominates(fp, point):
                return
        # Post-condition of the loop: no remaining member dominates *point*,
        # ergo it belongs on the frontier.
        self.frontier.append(point)
class OptimizedFrontier:
    """Pareto frontier that uses two bounding points to prune the scan.

    Same interface as ``Frontier``.  In addition it keeps:

    * ``best_theoretical_point``: per objective, the best value among points
      accepted onto the frontier since the last reset;
    * ``worst_theoretical_point``: per objective, the worst such value.

    The bounds are never tightened when members are removed, so they bound
    a *superset* of the current frontier.  Both shortcuts in
    ``recalculate_frontier`` stay valid for any such superset; staleness only
    weakens pruning, it never changes the result.

    ``objectives`` maps each objective name to an object with a boolean
    ``minimize`` attribute (presumably Ax ``ObjectiveProperties`` -- confirm).
    """

    def __init__(self, objectives):
        self.frontier = []
        self.objectives = objectives
        # Both bounds stay None until the first point arrives.
        self.best_theoretical_point = None
        self.worst_theoretical_point = None

    def load_df(self, df):
        """Feed every row of *df* whose ``trial_status`` is 'COMPLETED'.

        Rows with any other status are skipped.  Returns ``self.frontier``
        (previously an unrelated, always-empty list was returned).
        """
        for _, point in df.iterrows():
            if point['trial_status'] != 'COMPLETED':
                continue
            self.recalculate_frontier(point)
        return self.frontier

    def pareto_dominates(self, a, b):
        """Return True if *a* is at least as good as *b* on every objective.

        This is *weak* dominance: two points with identical objective values
        dominate each other.
        """
        for ob_name, ob_prop in self.objectives.items():
            if ob_prop.minimize:
                if a[ob_name] > b[ob_name]:
                    return False
            elif a[ob_name] < b[ob_name]:
                return False
        return True

    def recalculate_limit_points(self, point):
        """Widen the best/worst bounds so that they also cover *point*."""
        if self.best_theoretical_point is None:
            assert self.worst_theoretical_point is None
            # Copies, so that later in-place updates never touch *point*.
            self.best_theoretical_point = dict(point)
            self.worst_theoretical_point = dict(point)
            return
        for ob_name, ob_prop in self.objectives.items():
            opt_fn, worst_fn = (min, max) if ob_prop.minimize else (max, min)
            self.best_theoretical_point[ob_name] = opt_fn(
                self.best_theoretical_point[ob_name],
                point[ob_name],
            )
            self.worst_theoretical_point[ob_name] = worst_fn(
                self.worst_theoretical_point[ob_name],
                point[ob_name],
            )

    def recalculate_frontier(self, point):
        """Insert *point* unless a frontier member dominates it.

        Members that *point* weakly dominates are removed.
        """
        if self.best_theoretical_point is None:
            # First point: it is trivially the whole frontier.
            self.frontier = [point]
            self.recalculate_limit_points(point)
            return
        if self.pareto_dominates(point, self.best_theoretical_point):
            # At least as good as the per-objective best of a superset of the
            # frontier, so it weakly dominates every member: start over.
            self.best_theoretical_point = dict(point)
            self.worst_theoretical_point = dict(point)
            self.frontier = [point]
            return
        if self.pareto_dominates(self.worst_theoretical_point, point):
            # Every member is at least as good as the worst bound, which is at
            # least as good as *point*.  So *point* is worse than (or equal
            # to) every point in the frontier and is rejected.
            return
        # Optimization: members never weakly dominate one another.  Once
        # *point* dominates one member, no other member can dominate *point*,
        # because by transitivity it would then dominate the removed member.
        # So the reverse check is skipped from then on.
        dominates = False
        # Walk backwards so deleting index i never shifts indices still to
        # be visited.
        for i in reversed(range(len(self.frontier))):
            fp = self.frontier[i]
            if self.pareto_dominates(point, fp):
                del self.frontier[i]
                dominates = True
            elif (not dominates) and self.pareto_dominates(fp, point):
                return
        # Post-condition of the loop: no remaining member dominates *point*,
        # ergo it belongs on the frontier.
        self.frontier.append(point)
        # Only *point* needs folding into the bounds: every other member was
        # folded in when it was accepted.
        self.recalculate_limit_points(point)
import pandas as pd

# Trial data exported beforehand via:
#   ax_client.get_trials_data_frame().to_csv('./df.csv')
df = pd.read_csv('./df.csv')

from target import OBJ  # project-local objective definitions

# Run both implementations over the same trials so their frontier sizes can
# be compared side by side.  The probe count is reported once, right after
# the first (reference) implementation has loaded the data.
for frontier_index, frontier_cls in enumerate((Frontier, OptimizedFrontier)):
    frontier = frontier_cls(OBJ)
    frontier.load_df(df)
    if frontier_index == 0:
        print('Points probed:', len(df))
    print('Points frontier', len(frontier.frontier))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment