Last active: December 13, 2019 09:28
-
-
Save showgood163/346632be074b4436c10444dca14e1251 to your computer and use it in GitHub Desktop.
AxClient hack for hyperparameter search that never evaluates the same parameters twice
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from botorch.utils.sampling import manual_seed | |
import warnings | |
from ax.utils.common.typeutils import not_none | |
from ax.modelbridge.modelbridge_utils import get_pending_observation_features | |
from ax.utils.common.logger import _round_floats_for_logging, get_logger | |
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
from ax.core.generator_run import GeneratorRun | |
from ax.core.types import TParameterization | |
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy | |
from ax.modelbridge.registry import Models | |
from ax.service.ax_client import AxClient | |
import json | |
from ax.storage.json_store.encoder import object_to_json | |
from ax.storage.json_store.decoder import generation_strategy_from_json, object_from_json | |
# Log under the same logger name as the module AxClient is imported from
# (`ax.service.ax_client`), presumably so these messages appear alongside
# AxClient's own output -- confirm against the Ax version in use.
logger = get_logger('ax.service.ax_client')


class AxClientExt(AxClient):
    """`AxClient` that can additionally generate purely random (Sobol) trials.

    Besides the regular generation strategy used by `get_next_trial`, this
    client keeps a separate single-step Sobol strategy
    (`random_generation_strategy`) that `get_next_random_trial` draws from.
    Drawing a random point is a simple way to move on when the main strategy
    proposes a parameterization that has already been evaluated.
    """

    def __init__(
        self,
        generation_strategy: Optional[GenerationStrategy] = None,
        db_settings: Any = None,
        enforce_sequential_optimization: bool = True,
        random_seed: Optional[int] = None,
        verbose_logging: bool = True,
    ) -> None:
        """Create the client; arguments are forwarded unchanged to `AxClient`."""
        super().__init__(
            generation_strategy,
            db_settings,
            enforce_sequential_optimization,
            random_seed,
            verbose_logging,
        )
        # Must run after the base __init__, which sets
        # `_enforce_sequential_optimization` read by the helper.
        self.random_generation_strategy = self._make_random_generation_strategy()

    def _make_random_generation_strategy(self) -> GenerationStrategy:
        """Build the Sobol-only strategy used by `get_next_random_trial`."""
        return GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                # NOTE(review): -1 presumably means "no arm limit" for this
                # single step in this Ax version -- confirm.
                num_arms=-1,
                min_arms_observed=0,
                enforce_num_arms=self._enforce_sequential_optimization,
                model_kwargs=None,
            )
        ])

    def _gen_new_random_generator_run(self, n: int = 1) -> GeneratorRun:
        """Generate a new generator run from the random (Sobol) strategy.

        Args:
            n: Number of arms to generate.

        Returns:
            The `GeneratorRun` produced by `random_generation_strategy.gen`.
        """
        new_data = self._get_new_data()
        # If no random seed is set, `manual_seed` does nothing; otherwise it
        # seeds torch only for the scope of this call, because the torch seed
        # is global and leaving it set could hurt models that rely on
        # stochasticity.
        # Both managers must be entered, hence the comma: the expression
        # `manual_seed(...) and warnings.catch_warnings()` would evaluate to
        # the second manager alone and skip seeding entirely.
        with manual_seed(seed=self._random_seed), warnings.catch_warnings():
            # Filter out GPyTorch warnings to avoid confusing users.
            warnings.simplefilter("ignore")
            return not_none(self.random_generation_strategy).gen(
                experiment=self.experiment,
                new_data=new_data,
                n=n,
                pending_observations=get_pending_observation_features(
                    experiment=self.experiment),
            )

    def get_next_random_trial(self) -> Tuple[TParameterization, int]:
        """Create and dispatch a trial whose parameters come from Sobol sampling.

        Note: Service API currently supports only 1-arm trials.

        Returns:
            Tuple of trial parameterization, trial index.
        """
        trial = self.experiment.new_trial(
            generator_run=self._gen_new_random_generator_run())
        logger.info(
            f"Generated new trial {trial.index} with parameters "
            f"{_round_floats_for_logging(item=not_none(trial.arm).parameters)}.")
        trial.mark_dispatched()
        self._updated_trials = []
        self._save_experiment_and_generation_strategy_to_db_if_possible()
        return not_none(trial.arm).parameters, trial.index

    @staticmethod
    def load_from_json_file(
        filepath: str = "ax_client_snapshot.json") -> "AxClientExt":
        """Restore an `AxClientExt` and its state from a JSON snapshot file.

        Args:
            filepath: Path of the .json file holding the serialized snapshot.

        Returns:
            The restored client.
        """
        with open(filepath, "r") as file:  # pragma: no cover
            serialized = json.load(file)
        # `from_json_snapshot` constructs the client through `__init__`, which
        # already builds `random_generation_strategy`; no extra rebuild needed.
        return AxClientExt.from_json_snapshot(serialized=serialized)

    @staticmethod
    def from_json_snapshot(serialized: Dict[str, Any]) -> "AxClientExt":
        """Recreate an `AxClientExt` from a JSON snapshot dict.

        Args:
            serialized: Snapshot dict (presumably as written by the inherited
                `save_to_json_file` / `to_json_snapshot`). It must contain the
                keys "experiment", "generation_strategy",
                "_enforce_sequential_optimization" and "_updated_trials".
                The dict is copied first, so the caller's object is not mutated.

        Returns:
            The recreated client.
        """
        # Shallow copy so the pops below do not mutate the caller's dict.
        serialized = dict(serialized)
        experiment = object_from_json(serialized.pop("experiment"))
        serialized_generation_strategy = serialized.pop("generation_strategy")
        generation_strategy = (
            generation_strategy_from_json(
                generation_strategy_json=serialized_generation_strategy)
            if serialized_generation_strategy is not None
            else None
        )
        ax_client = AxClientExt(
            generation_strategy=generation_strategy,
            enforce_sequential_optimization=serialized.pop(
                "_enforce_sequential_optimization"),
        )
        ax_client._experiment = experiment
        ax_client._updated_trials = object_from_json(
            serialized.pop("_updated_trials"))
        return ax_client

    def mark_the_latest_trial_abandoned(self) -> None:
        """Mark the most recently created trial of the experiment as abandoned.

        Raises:
            KeyError: if the experiment has no trials.
        """
        trials = self.experiment.trials
        # Assumes trial indices run 0..len-1 in creation order -- TODO confirm
        # for experiments restored from storage.
        trials[len(trials) - 1].mark_abandoned()
if __name__ == "__main__":
    from os.path import isfile

    def get_result(parameters):
        """Toy objective |a + b + c|, returned as {metric: (mean, 0.)}.

        The second tuple element is presumably the SEM, i.e. a noiseless
        observation -- confirm against the Ax `complete_trial` docs.
        """
        return {
            'fc': (abs(parameters['a'] + parameters['b'] + parameters['c']), 0.)
        }

    hyper_param_search_list = [
        {
            "name": "a",
            "type": "range",
            "bounds": [-4, 4],
            "value_type": "int",
            "log_scale": False,
        },
        {
            "name": "b",
            "type": "range",
            "bounds": [-4, 4],
            "value_type": "int",
            "log_scale": False,
        },
        {
            "name": "c",
            "type": "range",
            "bounds": [1, 9],
            "value_type": "int",
            "log_scale": False,
        },
    ]

    ax_archive_file = 'tmp.json'
    # Number of trials to evaluate in this run (the archive keeps earlier ones).
    max_completed_trials = 50
    # The integer search space has only 9 * 9 * 9 = 729 points. Once (nearly)
    # all of them are evaluated, every proposal is a duplicate. Give up after
    # this many duplicates in a row instead of looping forever.
    max_consecutive_duplicates = 100

    if isfile(ax_archive_file):
        ax = AxClientExt.load_from_json_file(ax_archive_file)
    else:
        ax = AxClientExt()
        ax.create_experiment(
            name='tmp',
            parameters=hyper_param_search_list,
            objective_name="fc",
            minimize=True,
            parameter_constraints=None,
            outcome_constraints=None)

    def is_duplicate(parameters, trial_index):
        """True if another, already finished trial used exactly `parameters`."""
        return any(
            trial.index != trial_index
            and trial.status.is_terminal
            and trial.arm.parameters == parameters
            for trial in ax.experiment.trials.values()
        )

    # Do not run experiments with the same parameters twice: when the main
    # strategy proposes a duplicate, abandon it and draw a random trial next.
    random_flag = False
    num_completed = 0
    consecutive_duplicates = 0
    while num_completed < max_completed_trials:
        if random_flag:
            parameters, trial_index = ax.get_next_random_trial()
        else:
            parameters, trial_index = ax.get_next_trial()
        random_flag = is_duplicate(parameters, trial_index)
        if random_flag:
            print('trial abandoned because of duplication')
            ax.mark_the_latest_trial_abandoned()
            consecutive_duplicates += 1
            if consecutive_duplicates >= max_consecutive_duplicates:
                print('search space appears exhausted; stopping')
                break
            continue
        consecutive_duplicates = 0
        raw_data = get_result(parameters)
        ax.complete_trial(trial_index=trial_index, raw_data=raw_data)
        ax.save_to_json_file(ax_archive_file)
        num_completed += 1
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.