Skip to content

Instantly share code, notes, and snippets.

@cschell
Created January 28, 2022 08:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cschell/f4d2de50c9f34fddf38acef4b070ea92 to your computer and use it in GitHub Desktop.
Save cschell/f4d2de50c9f34fddf38acef4b070ea92 to your computer and use it in GitHub Desktop.
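Custom version of Optuna's `GridSampler` that ignores failed trials: a grid point whose trial failed is not counted as visited, so it is sampled again. In `_get_unvisited_grid_ids` the only change from the original is a single line (line 23 of the gist): only `COMPLETE` and `PRUNED` trials count as visited.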
from typing import List
from optuna.samplers import GridSampler
from optuna.study import Study
from optuna.trial import TrialState
class CustomGridSampler(GridSampler):
    """Grid sampler that retries grid points whose trials failed.

    Behaves like Optuna's ``GridSampler`` except that a trial in
    ``TrialState.FAIL`` does not mark its grid point as visited: only
    ``COMPLETE`` and ``PRUNED`` trials do. A grid point whose trial failed is
    therefore returned again by ``_get_unvisited_grid_ids`` and re-sampled.

    NOTE(review): this relies on private ``GridSampler`` internals
    (``_n_min_trials``, ``_same_search_space``) and on ``Study._storage`` /
    ``Study._study_id``. These match Optuna 2.x-era code; confirm them against
    the Optuna version in use.
    """

    # Trial states that mark a grid point as done. FAIL is deliberately absent
    # so that failed grid points are handed out again.
    _VISITED_STATES = (TrialState.COMPLETE, TrialState.PRUNED)

    def _get_unvisited_grid_ids(self, study: Study) -> List[int]:
        """Return the ids of grid points that still need a successful trial.

        Grid points with a ``COMPLETE`` or ``PRUNED`` trial count as visited.
        Grid points with a ``RUNNING`` trial are skipped while any other
        unvisited grid point remains. Once every unvisited grid point is
        already running, those running ones are returned, so that all grid
        points are evaluated before the optimization stops.

        Args:
            study: The study whose trials are inspected.

        Returns:
            Unvisited grid ids, in no particular order.
        """
        visited_ids = set()
        running_ids = set()

        # We directly query the storage to get trials here instead of `study.get_trials`,
        # since some pruners such as `HyperbandPruner` use the study transformed
        # to filter trials. See https://github.com/optuna/optuna/issues/2327 for details.
        trials = study._storage.get_all_trials(study._study_id, deepcopy=False)
        for trial in trials:
            attrs = trial.system_attrs
            # Ignore trials not created by a grid sampler with this search space.
            if "grid_id" not in attrs or not self._same_search_space(attrs["search_space"]):
                continue
            if trial.state in self._VISITED_STATES:
                visited_ids.add(attrs["grid_id"])
            elif trial.state == TrialState.RUNNING:
                running_ids.add(attrs["grid_id"])

        all_ids = set(range(self._n_min_trials))
        unvisited_ids = all_ids - visited_ids - running_ids
        # If evaluations for all grids have been started, return grids that have not yet finished
        # because all grids should be evaluated before stopping the optimization.
        if not unvisited_ids:
            unvisited_ids = all_ids - visited_ids
        return list(unvisited_ids)

    def after_trial(self, study: Study, trial, state: TrialState, values) -> None:
        """Forward to ``GridSampler.after_trial`` unless a failed trial must be retried.

        NOTE(review): in Optuna 2.x, ``GridSampler.after_trial`` runs while the
        finishing trial is still stored as RUNNING. It calls ``study.stop()``
        when the only unvisited grid id left is the current trial's own. For
        a failed trial, that would end the study without retrying its grid
        point, so a FAIL outcome is forwarded only when no unvisited grid ids
        remain at all. Confirm this against the Optuna version in use.

        Args:
            study: The study the trial belongs to.
            trial: The trial that just finished.
            state: The trial's final state.
            values: The trial's objective values, as passed by Optuna.
        """
        if state == TrialState.FAIL and self._get_unvisited_grid_ids(study):
            # Grid points (possibly this failed one) still need evaluating:
            # keep the study running so they get (re-)sampled.
            return
        super().after_trial(study, trial, state, values)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment