Created
October 27, 2021 16:23
-
-
Save xwjiang2010/bc1ee5d2b16717fbb4367727acaf7e78 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter, deque, defaultdict | |
from typing import Mapping | |
import ray | |
from ray.tune.utils.placement_groups import PlacementGroupFactory | |
class TrialRunner:
    """Drives the trial lifecycle: pending -> scheduled -> running -> paused/terminated.

    Trials are added as *pending*, moved to *scheduled* when resources are
    not yet available, and to *running* once the executor started them.
    """

    def __init__(self, trial_executor, max_resource_requests: int):
        self._pending_trials = []  # All trials we have configurations for
        self._scheduled_trials = []  # Trials we want to start **next**
        self._running_trials = set()  # Running trials
        self._paused_trials = set()  # Paused trials
        self._terminated_trials = set()  # Terminated trials
        self._trial_executor = trial_executor
        # future -> trial maps. Membership tests and pops key on the future
        # (see _process_training_future), so the *keys* are the futures.
        self._training_futures = {}
        self._saving_futures = {}
        self._restoring_futures = {}
        self._max_resource_requests = max_resource_requests

    def add_trial(self, trial):
        """Add a trial. This will add it to the list of pending trials."""
        self._pending_trials.append(trial)

    def start_trial(self, trial) -> bool:
        """Start ``trial`` now if the executor can; otherwise schedule it.

        Returns:
            True if the trial was actually started, False if it was only
            scheduled (or left scheduled) for a later attempt.
        """
        # Accept trials from either queue: _start_next_trial retries
        # scheduled trials through this same entry point.
        assert trial in self._pending_trials or trial in self._scheduled_trials
        if self._trial_executor.start_trial(trial):
            trial.status = "RUNNING"
            self._remove_unstarted(trial)
            self._running_trials.add(trial)
            self._train(trial)  # NOTE(review): _train not visible in this file
            return True
        if trial in self._pending_trials:
            # Promote to the scheduled queue, which is retried first.
            # (Original called list.add(), which does not exist.)
            trial.status = "SCHEDULED"
            self._pending_trials.remove(trial)
            self._scheduled_trials.append(trial)
        return False

    def pause_trial(self, trial):
        """Pause a running trial if the executor manages to stop it."""
        assert trial in self._running_trials
        if self._trial_executor.stop_trial(trial):
            self._remove_trial_futures(trial)
            trial.status = "PAUSED"
            self._running_trials.remove(trial)
            self._paused_trials.add(trial)

    def stop_trial(self, trial):
        """Terminate a running trial if the executor manages to stop it."""
        assert trial in self._running_trials
        if self._trial_executor.stop_trial(trial):
            self._remove_trial_futures(trial)
            trial.status = "TERMINATED"
            self._running_trials.remove(trial)
            self._terminated_trials.add(trial)

    def step(self):
        """Run one scheduling iteration: request resources, start a trial
        if possible, otherwise block on the next resolved future."""
        # Request resources for up to `max_resource_requests` trials;
        # scheduled trials take priority over pending ones.
        request_trials = self._scheduled_trials + self._pending_trials
        can_start_trial = self._trial_executor.request_resources(
            request_trials[:self._max_resource_requests])
        # If we can start a trial, start.
        wait_for_future = True
        if can_start_trial:
            if self._start_next_trial():
                wait_for_future = False
        # If we didn't start a trial, wait for futures.
        if wait_for_future:
            # Futures are the dict *keys* (the original used .values(),
            # which are trials, and `dict_values + list` raises TypeError).
            # Saving/restoring futures are included so the dispatch branches
            # below are actually reachable.
            futures = (list(self._training_futures)
                       + list(self._saving_futures)
                       + list(self._restoring_futures)
                       + list(self._trial_executor.placement_futures))
            ready, not_ready = ray.wait(futures, num_returns=1)
            for ready_future in ready:
                if ready_future in self._training_futures:
                    self._process_training_future(ready_future)
                elif ready_future in self._saving_futures:
                    self._process_saving_future(ready_future)
                elif ready_future in self._restoring_futures:
                    self._process_restoring_future(ready_future)
                # ...
                else:
                    self._process_placement_future(ready_future)

    def _remove_unstarted(self, trial):
        """Drop ``trial`` from whichever not-yet-running queue holds it."""
        if trial in self._pending_trials:
            self._pending_trials.remove(trial)
        if trial in self._scheduled_trials:
            self._scheduled_trials.remove(trial)

    def _remove_trial_futures(self, trial):
        """Forget every outstanding future that belongs to ``trial``."""
        def _remove_from_future_dict(dt, trial):
            # Iterate over a copy so we can pop while scanning.
            for f, t in list(dt.items()):
                if t == trial:
                    dt.pop(f)
        _remove_from_future_dict(self._training_futures, trial)
        _remove_from_future_dict(self._saving_futures, trial)
        _remove_from_future_dict(self._restoring_futures, trial)

    def _process_training_future(self, future):
        """Fetch results of a finished training step and process them."""
        trial = self._training_futures.pop(future)
        results = ray.get(future)
        self._process_trial_results(trial, results)

    def _process_placement_future(self, future):
        """A placement group became ready — try to start the next trial."""
        self._start_next_trial()

    def _start_next_trial(self) -> bool:
        """Start the next startable trial.

        Scheduled trials are tried first, then pending trials. Returns True
        as soon as one trial was started. (The original popped from
        ``self._paused_trials`` — a set, where ``pop(0)`` raises — instead
        of draining the scheduled/pending lists.)
        """
        # Iterate over copies: start_trial mutates the underlying lists.
        for candidate in list(self._scheduled_trials):
            if self.start_trial(candidate):
                return True
        for candidate in list(self._pending_trials):
            if self.start_trial(candidate):
                return True
        return False

    def _process_trial_results(self, trial, results):
        """Process each intermediate result of a training step in order."""
        for result in results:
            self._process_trial_result(trial, result)

    def _process_trial_result(self, trial, results):
        # Intentionally a stub in this design sketch.
        pass
class TrialExecutor:
    """Starts trials on cached actors or ready placement groups.

    Resource demand is forwarded to a ``PlacementGroupManager``; actor
    readiness is tracked through ``_placement_futures``.
    """

    def __init__(self):
        self._last_trial_request = []
        # PGF -> cache of at most one (actor, pg) pair kept for reuse
        # (reuse_trials=True path).
        self._cached_pgf_actor_pgs = defaultdict(lambda: deque(maxlen=1))
        self._pg_manager = PlacementGroupManager()
        # Outstanding actor-placement/readiness futures. A set, not a dict:
        # futures are only added and waited on (original used `{}` but then
        # called `.add(...)`, which a dict does not have).
        self._placement_futures = set()

    @property
    def placement_futures(self):
        """Snapshot list of outstanding placement futures.

        Read by ``TrialRunner.step`` when building its ``ray.wait`` list.
        """
        return list(self._placement_futures)

    def request_resources(self, trials):
        """Request placement groups for ``trials``.

        Returns:
            True if at least one of the requested trials could be started
            right now (cached actor available or a PG already ready).
            The original returned None, so the caller's ``can_start_trial``
            check never fired.
        """
        self._last_trial_request = trials
        pgf_counter = Counter(trial.pgf for trial in trials)
        self._pg_manager.request_pgs(pgf_counter)
        return any(
            len(self._cached_pgf_actor_pgs[trial.pgf]) > 0
            or self._pg_manager.has_ready_pg(trial.pgf)
            for trial in trials)

    def start_trial(self, trial):
        """Try to start ``trial``; returns True on success, False otherwise."""
        if len(self._cached_pgf_actor_pgs[trial.pgf]) > 0:
            # There is a cached actor + pg (reuse_trials=True): reset the
            # actor with the new trial config instead of creating one.
            actor, pg = self._cached_pgf_actor_pgs[trial.pgf].popleft()
            placement_future = actor.reset.remote(config=trial.config)
            self._placement_futures.add(placement_future)
            return True
        if self._pg_manager.has_ready_pg(trial.pgf):
            # There is no cached actor, but a PG is ready. Start a new
            # remote actor on it.
            pg = self._pg_manager.get_ready_pg(trial.pgf)
            actor = self._create_trainable_actor(trial, pg)
            placement_future = actor.ready.remote()
            self._placement_futures.add(placement_future)
            return True
        return False
class PlacementGroupManager:
    """Tracks staged and ready placement groups per PlacementGroupFactory."""

    def __init__(self):
        # Map of PGF to staged PGs
        self._staged_pgf_pgs = defaultdict(list)
        # Map of PGF to ready PGs
        self._ready_pgf_pgs = defaultdict(list)

    # String annotation: avoids evaluating the ray type at definition time.
    def request_pgs(self, pg_map: "Mapping[PlacementGroupFactory, int]"):
        """Reconcile staged placement groups against the requested counts.

        Basically copy reconcile placement groups here — intentionally a
        stub in this design sketch.
        """
        pass

    def has_ready_pg(self, pgf) -> bool:
        """Whether a ready PG exists for ``pgf``.

        Called by ``TrialExecutor.start_trial``; was missing entirely.
        """
        return len(self._ready_pgf_pgs[pgf]) > 0

    def get_ready_pg(self, pgf):
        """Pop and return a ready PG for ``pgf``.

        Callers must check ``has_ready_pg`` first; raises IndexError when
        no PG is ready.
        """
        return self._ready_pgf_pgs[pgf].pop(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment