Skip to content

Instantly share code, notes, and snippets.

@showgood163
Last active December 13, 2019 09:28
Show Gist options
  • Save showgood163/346632be074b4436c10444dca14e1251 to your computer and use it in GitHub Desktop.
Save showgood163/346632be074b4436c10444dca14e1251 to your computer and use it in GitHub Desktop.
AxClient hack for non-repeating hyperparameter search
from botorch.utils.sampling import manual_seed
import warnings
from ax.utils.common.typeutils import not_none
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from typing import Any, Dict, List, Optional, Tuple, Union
from ax.core.generator_run import GeneratorRun
from ax.core.types import TParameterization
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient
import json
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.decoder import generation_strategy_from_json, object_from_json
# Log under the name 'ax.service.ax_client', i.e. the module that defines the
# AxClient imported above, so messages from this subclass presumably share the
# parent client's logger configuration (confirm against Ax's get_logger).
logger = get_logger('ax.service.ax_client')
class AxClientExt(AxClient):
    """`AxClient` with an extra Sobol-only generation strategy.

    Besides the client's main generation strategy, this subclass keeps a
    single-step Sobol `GenerationStrategy` in `random_generation_strategy`.
    `get_next_random_trial` draws from it, e.g. to replace a proposal from the
    main strategy whose parameters were already evaluated.
    """

    def __init__(
        self,
        generation_strategy: Optional[GenerationStrategy] = None,
        db_settings: Any = None,
        enforce_sequential_optimization: bool = True,
        random_seed: Optional[int] = None,
        verbose_logging: bool = True,
    ) -> None:
        # Forward by keyword so a reordering of AxClient.__init__'s parameters
        # fails loudly instead of silently misrouting arguments.
        super().__init__(
            generation_strategy=generation_strategy,
            db_settings=db_settings,
            enforce_sequential_optimization=enforce_sequential_optimization,
            random_seed=random_seed,
            verbose_logging=verbose_logging,
        )
        self.random_generation_strategy = self._make_random_generation_strategy(
            enforce_num_arms=self._enforce_sequential_optimization)

    @staticmethod
    def _make_random_generation_strategy(
            enforce_num_arms: bool) -> GenerationStrategy:
        """Build the single-step Sobol strategy used for random trials."""
        return GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                # Presumably -1 means "no limit on arms from this step" in
                # Ax's GenerationStep; confirm for the installed Ax version.
                num_arms=-1,
                min_arms_observed=0,
                enforce_num_arms=enforce_num_arms,
                model_kwargs=None,
            )
        ])

    def _gen_new_random_generator_run(self, n: int = 1) -> GeneratorRun:
        """Generate a new generator run from the Sobol strategy.

        Args:
            n: Number of arms to generate.

        Returns:
            The `GeneratorRun` produced by `random_generation_strategy.gen`.
        """
        new_data = self._get_new_data()
        # If no random seed is set for this optimization, `manual_seed` does
        # nothing; otherwise it seeds torch only for the scope of this block,
        # because a global torch seed would affect models that rely on
        # stochasticity. The managers must be comma-separated: with
        # `a and b` only `b` would be entered and the seed would be ignored.
        with manual_seed(seed=self._random_seed), warnings.catch_warnings():
            # Filter out GPyTorch warnings to avoid confusing users.
            warnings.simplefilter("ignore")
            return not_none(self.random_generation_strategy).gen(
                experiment=self.experiment,
                new_data=new_data,
                n=n,
                pending_observations=get_pending_observation_features(
                    experiment=self.experiment),
            )

    def get_next_random_trial(self) -> Tuple[TParameterization, int]:
        """Create a new trial from the Sobol strategy and dispatch it.

        The main generation strategy is not consulted. The trial is marked
        dispatched, and the client then tries to save itself to the DB.
        Note: the Service API currently supports only 1-arm trials.

        Returns:
            Tuple of (trial parameterization, trial index).
        """
        trial = self.experiment.new_trial(
            generator_run=self._gen_new_random_generator_run())
        logger.info(
            f"Generated new trial {trial.index} with parameters "
            f"{_round_floats_for_logging(item=not_none(trial.arm).parameters)}.")
        trial.mark_dispatched()
        # Reset the record of trials updated since the last save (assumed to
        # mirror AxClient.get_next_trial's bookkeeping; confirm).
        self._updated_trials = []
        self._save_experiment_and_generation_strategy_to_db_if_possible()
        return not_none(trial.arm).parameters, trial.index

    @staticmethod
    def load_from_json_file(
            filepath: str = "ax_client_snapshot.json") -> "AxClientExt":
        """Restore an `AxClientExt` and its state from a JSON snapshot file.

        Args:
            filepath: Path of the .json file holding the serialized snapshot.

        Returns:
            The restored client. Its `random_generation_strategy` is a fresh
            Sobol strategy built by `__init__`, because that strategy is not
            part of the snapshot.
        """
        with open(filepath, "r") as file:  # pragma: no cover
            serialized = json.load(file)
        # `from_json_snapshot` constructs the client via `__init__`, which
        # already builds `random_generation_strategy` from the restored
        # `_enforce_sequential_optimization`; no separate rebuild is needed.
        return AxClientExt.from_json_snapshot(serialized=serialized)

    @staticmethod
    def from_json_snapshot(serialized: Dict[str, Any]) -> "AxClientExt":
        """Recreate an `AxClientExt` from a JSON snapshot.

        Args:
            serialized: Snapshot dict with the keys "experiment",
                "generation_strategy", "_enforce_sequential_optimization"
                and "_updated_trials". The caller's dict is not modified.

        Returns:
            The reconstructed client.
        """
        # Shallow copy so that pop() below does not consume the caller's dict.
        serialized = dict(serialized)
        experiment = object_from_json(serialized.pop("experiment"))
        serialized_generation_strategy = serialized.pop("generation_strategy")
        generation_strategy = (
            generation_strategy_from_json(
                generation_strategy_json=serialized_generation_strategy)
            if serialized_generation_strategy is not None
            else None
        )
        ax_client = AxClientExt(
            generation_strategy=generation_strategy,
            enforce_sequential_optimization=serialized.pop(
                "_enforce_sequential_optimization"),
        )
        ax_client._experiment = experiment
        ax_client._updated_trials = object_from_json(
            serialized.pop("_updated_trials"))
        return ax_client

    def mark_the_latest_trial_abandoned(self) -> None:
        """Mark the experiment's most recently indexed trial as abandoned.

        Raises:
            ValueError: If the experiment has no trials.
        """
        trials = self.experiment.trials
        if not trials:
            raise ValueError(
                "Cannot abandon the latest trial: the experiment has no trials.")
        # Use the largest trial index instead of assuming indices are exactly
        # 0..len(trials) - 1.
        trials[max(trials)].mark_abandoned()
if __name__ == "__main__":
    from os.path import isfile

    def get_result(parameters):
        """Toy objective |a + b + c|, reported with an SEM of 0."""
        return {
            'fc': (abs(parameters['a'] + parameters['b'] + parameters['c']), 0.)
        }

    hyper_param_search_list = [
        {
            "name": "a",
            "type": "range",
            "bounds": [-4, 4],
            "value_type": "int",
            "log_scale": False,
        },
        {
            "name": "b",
            "type": "range",
            "bounds": [-4, 4],
            "value_type": "int",
            "log_scale": False,
        },
        {
            "name": "c",
            "type": "range",
            "bounds": [1, 9],
            "value_type": "int",
            "log_scale": False,
        },
    ]
    ax_archive_file = 'tmp.json'
    # Stop after this many newly completed trials. Without a budget the loop
    # never terminates.
    num_new_trials = 50
    # The integer grid has only 9 * 9 * 9 = 729 points. Once it is (nearly)
    # exhausted every proposal is a duplicate, so give up after this many
    # duplicates in a row instead of spinning forever.
    max_consecutive_duplicates = 100

    if isfile(ax_archive_file):
        ax = AxClientExt.load_from_json_file(ax_archive_file)
    else:
        ax = AxClientExt()
        ax.create_experiment(
            name='tmp',
            parameters=hyper_param_search_list,
            objective_name="fc",
            minimize=True,
            parameter_constraints=None,
            outcome_constraints=None)

    # Do not run experiments with the same parameters twice: when the main
    # strategy re-proposes an evaluated point, abandon that trial and ask the
    # Sobol strategy for a random one instead.
    use_random = False
    completed = 0
    consecutive_duplicates = 0
    while completed < num_new_trials:
        if use_random:
            parameters, trial_index = ax.get_next_random_trial()
        else:
            parameters, trial_index = ax.get_next_trial()
        # A duplicate is a trial with these parameters that has already
        # reached a terminal status. The trial just generated has not, so it
        # never matches itself.
        is_duplicate = any(
            trial.arm.parameters == parameters and trial.status.is_terminal
            for trial in ax.experiment.trials.values()
        )
        if is_duplicate:
            print('trial abandoned because of duplication')
            ax.mark_the_latest_trial_abandoned()
            consecutive_duplicates += 1
            if consecutive_duplicates >= max_consecutive_duplicates:
                print('stopping: no new parameters found, '
                      'search space appears exhausted')
                break
            use_random = True
            continue
        use_random = False
        consecutive_duplicates = 0
        raw_data = get_result(parameters)
        ax.complete_trial(trial_index=trial_index, raw_data=raw_data)
        ax.save_to_json_file(ax_archive_file)
        completed += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment