Skip to content

Instantly share code, notes, and snippets.

@xwjiang2010
Created October 27, 2021 16:23
Show Gist options
  • Save xwjiang2010/bc1ee5d2b16717fbb4367727acaf7e78 to your computer and use it in GitHub Desktop.
Save xwjiang2010/bc1ee5d2b16717fbb4367727acaf7e78 to your computer and use it in GitHub Desktop.
from collections import Counter, deque, defaultdict
from typing import Mapping
import ray
from ray.tune.utils.placement_groups import PlacementGroupFactory
class TrialRunner:
    """Sketch of a trial runner that moves trials between lifecycle queues.

    Trials flow ``pending -> scheduled -> running -> paused/terminated``.
    Actually placing a trial is delegated to ``trial_executor``; this class
    tracks which queue each trial is in and which remote futures belong to
    which trial.

    NOTE(review): ``_train``, ``_process_saving_future`` and
    ``_process_restoring_future`` are called below but not defined in this
    class — they must be provided elsewhere before this runs end to end.
    """

    def __init__(self, trial_executor, max_resource_requests: int):
        self._pending_trials = []  # All trials we have configurations for
        self._scheduled_trials = []  # Trials we want to start **next**
        self._running_trials = set()  # Running trials
        self._paused_trials = set()  # Paused trials
        self._terminated_trials = set()  # Terminated trials
        self._trial_executor = trial_executor
        # Each of these maps a remote future -> the trial that owns it.
        self._training_futures = {}
        self._saving_futures = {}
        self._restoring_futures = {}
        # Max number of trials resources are requested for in one step.
        self._max_resource_requests = max_resource_requests

    def add_trial(self, trial):
        """Add a trial. This will add it to the list of pending trials."""
        self._pending_trials.append(trial)

    def start_trial(self, trial) -> bool:
        """Start ``trial`` if the executor can place it right now.

        The trial must currently be pending or scheduled. On success it is
        removed from its queue, added to the running set and training is
        kicked off. On failure a pending trial is moved to the end of the
        scheduled queue (scheduled trials are requested first in ``step``);
        an already-scheduled trial keeps its position.

        Returns:
            True if the trial was started, False otherwise.
        """
        in_pending = trial in self._pending_trials
        assert in_pending or trial in self._scheduled_trials
        if self._trial_executor.start_trial(trial):
            trial.status = "RUNNING"
            if in_pending:
                self._pending_trials.remove(trial)
            else:
                self._scheduled_trials.remove(trial)
            self._running_trials.add(trial)
            self._train(trial)
            return True
        if in_pending:
            trial.status = "SCHEDULED"
            self._pending_trials.remove(trial)
            self._scheduled_trials.append(trial)
        return False

    def pause_trial(self, trial):
        """Pause a running trial if the executor manages to stop it.

        Any outstanding futures of the trial are dropped.
        """
        assert trial in self._running_trials
        if self._trial_executor.stop_trial(trial):
            self._remove_trial_futures(trial)
            trial.status = "PAUSED"
            self._running_trials.remove(trial)
            self._paused_trials.add(trial)

    def stop_trial(self, trial):
        """Terminate a running trial if the executor manages to stop it.

        Any outstanding futures of the trial are dropped.
        """
        assert trial in self._running_trials
        if self._trial_executor.stop_trial(trial):
            self._remove_trial_futures(trial)
            trial.status = "TERMINATED"
            self._running_trials.remove(trial)
            self._terminated_trials.add(trial)

    def step(self):
        """Run one iteration of the scheduling loop.

        Requests resources for up to ``max_resource_requests`` trials
        (scheduled trials first, then pending ones). If the executor reports
        that a trial can be started and one actually starts, the step ends.
        Otherwise it blocks until one outstanding future is ready and
        dispatches it to the matching handler.
        """
        request_trials = self._scheduled_trials + self._pending_trials
        can_start_trial = self._trial_executor.request_resources(
            request_trials[:self._max_resource_requests])

        # If we can start a trial, start it and skip waiting this step.
        if can_start_trial and self._start_next_trial():
            return

        # The futures are the *keys* of the future -> trial dicts.
        futures = (
            list(self._training_futures)
            + list(self._saving_futures)
            + list(self._restoring_futures)
            + list(self._trial_executor.placement_futures)
        )
        if not futures:
            # Nothing outstanding to wait for.
            return

        ready, _ = ray.wait(futures, num_returns=1)
        for ready_future in ready:
            if ready_future in self._training_futures:
                self._process_training_future(ready_future)
            elif ready_future in self._saving_futures:
                self._process_saving_future(ready_future)
            elif ready_future in self._restoring_futures:
                self._process_restoring_future(ready_future)
            # ...
            else:
                self._process_placement_future(ready_future)

    def _remove_trial_futures(self, trial):
        """Drop every training/saving/restoring future owned by ``trial``."""
        def _remove_from_future_dict(dt, trial):
            # Iterate over a snapshot because entries are popped in the loop.
            for f, t in list(dt.items()):
                if t == trial:
                    dt.pop(f)

        _remove_from_future_dict(self._training_futures, trial)
        _remove_from_future_dict(self._saving_futures, trial)
        _remove_from_future_dict(self._restoring_futures, trial)

    def _process_training_future(self, future):
        """Fetch the results of a finished training future and process them."""
        trial = self._training_futures.pop(future)
        results = ray.get(future)
        self._process_trial_results(trial, results)

    def _process_placement_future(self, future):
        """Handle a ready placement future by trying to start the next trial.

        NOTE(review): nothing in this class removes ``future`` from the
        executor's outstanding placement futures, so it may be returned by
        ``ray.wait`` again on the next step — confirm where it is cleared.
        """
        self._start_next_trial()

    def _start_next_trial(self) -> bool:
        """Start at most one trial, preferring scheduled over pending ones.

        Scheduled trials are tried in order, then pending trials. The first
        one that starts ends the search. Trials that cannot start keep their
        relative order: scheduled ones stay put, and pending ones are moved
        to the scheduled queue by ``start_trial``.

        Returns:
            True if a trial was started, False otherwise.
        """
        # Iterate over snapshots: ``start_trial`` mutates both queues.
        for trial in list(self._scheduled_trials):
            if self.start_trial(trial):
                return True
        for trial in list(self._pending_trials):
            if self.start_trial(trial):
                return True
        return False

    def _process_trial_results(self, trial, results):
        """Process each result of ``results`` for ``trial`` in order."""
        for result in results:
            self._process_trial_result(trial, result)

    def _process_trial_result(self, trial, results):
        """Handle a single result for ``trial`` (not implemented yet)."""
        pass
class TrialExecutor:
    """Sketch of an executor that places trials on placement groups (PGs).

    A trial is started either by reusing a cached actor + PG for its
    placement group factory (``trial.pgf``) or by creating a new trainable
    actor on a PG the placement group manager reports as ready.
    """

    def __init__(self):
        self._last_trial_request = []
        # Map of PGF -> cached (actor, pg) pairs; at most one kept per PGF.
        self._cached_pgf_actor_pgs = defaultdict(lambda: deque(maxlen=1))
        self._pg_manager = PlacementGroupManager()
        # Map of placement future -> the trial it was issued for.
        self._placement_futures = {}

    @property
    def placement_futures(self):
        """List of placement futures issued by :meth:`start_trial`."""
        return list(self._placement_futures)

    def request_resources(self, trials) -> bool:
        """Request placement groups for ``trials`` and report startability.

        Records ``trials`` as the last request and passes the number of
        trials per PGF to the placement group manager.

        Returns:
            True if at least one of ``trials`` could be started right now,
            i.e. has a cached actor or a ready PG for its PGF.
        """
        self._last_trial_request = trials
        pgf_counter = Counter(trial.pgf for trial in trials)
        self._pg_manager.request_pgs(pgf_counter)
        return any(self._can_start(trial) for trial in trials)

    def _can_start(self, trial) -> bool:
        """Whether ``trial`` has a cached actor or a ready PG for its PGF."""
        # ``.get`` avoids creating empty entries in the defaultdict.
        if self._cached_pgf_actor_pgs.get(trial.pgf):
            return True
        return bool(self._pg_manager.has_ready_pg(trial.pgf))

    def start_trial(self, trial) -> bool:
        """Try to place ``trial`` on an actor.

        Prefers reusing a cached actor + PG for ``trial.pgf``; that actor is
        reset with ``trial.config``. Otherwise, if a PG is ready, a new
        trainable actor is created on it. The resulting remote future is
        recorded as a placement future for ``trial``.

        Returns:
            True if an actor was assigned, False if neither a cached actor
            nor a ready PG is available.
        """
        cached = self._cached_pgf_actor_pgs.get(trial.pgf)
        if cached:
            # There is a cached actor + pg (reuse_trials=True)
            actor, _pg = cached.popleft()
            placement_future = actor.reset.remote(config=trial.config)
        elif self._pg_manager.has_ready_pg(trial.pgf):
            # No cached actor, but a PG is ready: start a new remote actor.
            pg = self._pg_manager.get_ready_pg(trial.pgf)
            actor = self._create_trainable_actor(trial, pg)
            placement_future = actor.ready.remote()
        else:
            return False
        self._placement_futures[placement_future] = trial
        return True
class PlacementGroupManager:
    """Tracks staged and ready placement groups (PGs) per PG factory (PGF)."""

    def __init__(self):
        # Map of PGF to staged PGs
        self._staged_pgf_pgs = defaultdict(list)
        # Map of PGF to ready PGs
        self._ready_pgf_pgs = defaultdict(list)

    def request_pgs(self, pg_map: Mapping[PlacementGroupFactory, int]):
        """Request placement groups for each PGF in ``pg_map``.

        ``pg_map`` maps a PGF to an integer count (presumably the number of
        trials wanting that PGF — confirm against callers). Not implemented
        yet in this sketch.
        """
        # TODO: copy the placement group reconciliation logic here.
        pass

    def has_ready_pg(self, pgf) -> bool:
        """Return True if at least one ready PG exists for ``pgf``."""
        # ``.get`` avoids creating empty entries in the defaultdict.
        return bool(self._ready_pgf_pgs.get(pgf))

    def get_ready_pg(self, pgf):
        """Remove and return the oldest ready PG for ``pgf``.

        Raises:
            ValueError: if no ready PG exists for ``pgf``.
        """
        ready_pgs = self._ready_pgf_pgs.get(pgf)
        if not ready_pgs:
            raise ValueError(f"No ready placement group for {pgf!r}")
        return ready_pgs.pop(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment