Ax + SLURM via `submitit` and `asyncio`
#!/usr/bin/env python
"""
Asynchronous hyperparameter search using [Ax](https://ax.dev/) and a submitit
executor to run trials on SLURM.

Refs

* [Ax Service API tutorial](https://ax.dev/tutorials/gpei_hartmann_service.html)
* [submitit/docs/examples.md](https://github.com/facebookincubator/submitit/blob/07f21fa1234e34151874c00d80c345e215af4967/docs/examples.md?plain=1#L121)
"""
import os
import asyncio
import time

import submitit
import cloudpickle
import numpy as np

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.core import DataRequiredError
from ax.core.base_trial import TrialStatus
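
# The core submitit pattern this script builds on is the executor's awaitable
# job interface (see the examples.md link above). A minimal sketch, assuming a
# placeholder `slow_fn` and an async calling context:
#
#     executor = submitit.AutoExecutor(folder="jobs")
#     job = executor.submit(slow_fn, 42)
#     result = await job.awaitable().result()  # suspends until the job ends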

def is_successful(state):
    """
    The API in submitit seems to have drifted and the list of possible job
    states is large, so we define a helper to check whether a job has
    completed successfully.
    """
    return state in ('DONE', 'COMPLETED', 'FINISHED')


def is_running(state):
    """
    The API in submitit seems to have drifted and the list of possible job
    states is large, so we define a helper to check whether a job is still in
    progress.
    """
    return state in ('RUNNING', 'UNKNOWN', 'PENDING', 'STARTED', 'QUEUED')

def init_or_load_ax_client(
        json_file_path,
        experiment_name,
        parameters=None,
        objectives=None,
        parameter_constraints=None):
    """
    Initialize an AxClient from a JSON file if one exists, otherwise create a
    new experiment. Note that if the state is loaded from disk, the
    `parameters`, `objectives` and `parameter_constraints` arguments are
    ignored.
    """
    if os.path.exists(json_file_path):
        try:
            ax_client = AxClient.load_from_json_file(json_file_path)
            print(f"Successfully loaded AxClient state from {json_file_path}.")
            return ax_client
        except Exception as e:
            print(f"Failed to load AxClient from {json_file_path}: {e}. Initializing a new AxClient.")
    # If the file does not exist or loading failed, create a new AxClient instance.
    # We avoid mutable default arguments above and fill in empty lists here.
    ax_client = AxClient()
    ax_client.create_experiment(
        name=experiment_name,
        parameters=parameters or [],
        objectives=objectives,
        parameter_constraints=parameter_constraints or [],
    )
    print(f"Created a new experiment: {experiment_name}.")
    return ax_client
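
# Usage sketch (hypothetical experiment name and search space), e.g.:
#
#     ax_client = init_or_load_ax_client(
#         "experiments/my_exp.json", "my_exp",
#         parameters=[{"name": "lr", "type": "range",
#                      "bounds": [1e-5, 1e-1], "log_scale": True}],
#         objectives={"loss": ObjectiveProperties(minimize=True)})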

class JobManager:
    def __init__(self, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=60):
        self.executor = executor
        self.ax_client = ax_client
        self.ax_save_path = ax_save_path
        self.jobs_state_path = jobs_state_path if jobs_state_path else ax_save_path + '.jobs.pkl'
        self.jobs = {}  # Track jobs by trial_index, storing (job, parameters) tuples
        self.save_lock = asyncio.Lock()
        self.wait_interval = wait_interval
    async def safe_save_state(self):
        async with self.save_lock:
            # Save AxClient state
            try:
                self.ax_client.save_to_json_file(self.ax_save_path)
                print(f"Successfully saved AxClient state to {self.ax_save_path}")
            except Exception as e:
                print(f"Failed to save AxClient state: {e}")
            # Save jobs state
            try:
                with open(self.jobs_state_path, 'wb') as f:
                    cloudpickle.dump(self.jobs, f)
                print(f"Successfully saved jobs state to {self.jobs_state_path}")
            except Exception as e:
                print(f"Failed to save jobs state: {e}")
    @staticmethod
    def load_state(executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
        # Attempt to load the jobs state if it exists
        jobs_state_path = jobs_state_path if jobs_state_path else ax_save_path + '.jobs.pkl'
        if os.path.exists(jobs_state_path):
            try:
                with open(jobs_state_path, 'rb') as f:
                    jobs = cloudpickle.load(f)
                print(f"Successfully loaded jobs state from {jobs_state_path}")
                # Initialize a JobManager with the loaded jobs
                job_manager = JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
                job_manager.jobs = jobs
                return job_manager
            except Exception as e:
                print(f"Failed to load jobs state: {e}")
        # Fall back to a fresh JobManager instance if loading failed
        return JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
    async def process_job(self, fn, parameters, trial_index, is_new=True):
        if is_new:
            job = self.executor.submit(fn, parameters)
            self.jobs[trial_index] = (job, parameters)  # Store job and parameters
        else:
            job, _ = self.jobs[trial_index]  # Retrieve existing job and parameters
        try:
            print(f"Polling job {job.job_id} for trial {trial_index} in state {job.state}")
            while is_running(job.state):
                await asyncio.sleep(self.wait_interval)
                # Suspend until the job finishes (this raises if the job failed)
                await job.awaitable().result()
            if is_successful(job.state):
                result = job.result()
                self.ax_client.complete_trial(trial_index=trial_index, raw_data=result)
            else:
                raise ValueError(f"Job in unexpected state: {job.state}")
        except Exception as e:
            # For submitit failures, the exception text typically includes the
            # remote traceback
            job_stderr = str(e)
            self.ax_client.log_trial_failure(trial_index=trial_index, metadata={"stderr": job_stderr})
            print(f"Trial {trial_index} failed with error:\n{job_stderr}")
        del self.jobs[trial_index]
        await self.safe_save_state()
    async def process_all_jobs(self, fn=None):
        tasks = []
        for trial_index, (_, parameters) in list(self.jobs.items()):
            task = asyncio.create_task(self.process_job(fn, parameters, trial_index, is_new=False))
            tasks.append(task)
        await asyncio.gather(*tasks)

    def reattach_incomplete_jobs(self):
        for trial_index, (job, parameters) in self.jobs.items():
            if not is_successful(job.state):
                self.ax_client.attach_trial(parameters=parameters)
                print(f"Reattached parameters for incomplete job {trial_index}")

async def run_trials(
        fn, executor, ax_client, trial_budget=25,
        ax_save_path="experiments/ax_state.json",
        job_manager_save_path='experiments/ax_state.jobs.pkl', wait_interval=30):
    job_manager = JobManager.load_state(executor, ax_client, ax_save_path, job_manager_save_path, wait_interval=wait_interval)
    # Process all serialized jobs before starting new ones
    await job_manager.process_all_jobs(fn)
    trials_submitted = 0
    tasks = []
    while trials_submitted < trial_budget:
        try:
            parameters, trial_index = ax_client.get_next_trial()
            tasks.append(asyncio.create_task(
                job_manager.process_job(fn, parameters, trial_index, is_new=True)))
            trials_submitted += 1
        except (MaxParallelismReachedException, DataRequiredError) as e:
            print(f"Waiting for jobs to complete due to: {type(e).__name__}")
            await asyncio.sleep(wait_interval)
    print(f"Total trials submitted: {trials_submitted}")
    # Wait for all outstanding trial tasks; without this, asyncio.run() would
    # cancel them as soon as this coroutine returned.
    await asyncio.gather(*tasks)
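
# To smoke-test the loop without a SLURM cluster, a sketch: swap in submitit's
# DebugExecutor, which runs jobs locally instead of submitting to a scheduler:
#
#     executor = submitit.DebugExecutor(folder="jobs")
#     asyncio.run(run_trials(model_trial, executor, ax_client, trial_budget=2))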

if __name__ == "__main__":
    # These imports let us display an image if we execute this script from iTerm2
    import matplotlib.pyplot as plt
    import subprocess
    import tempfile

    # Load some example params and set up a function to use them
    experiment_name = "optim_hartmann6_dev"

    def model_trial(params):
        """
        A wrapper that takes the dictionary of parameters from Ax, runs the
        synthetic function hartmann6, and returns the objective values, after
        waiting a while to keep things difficult.
        """
        x = np.array([params.get(f"x{i+1}") for i in range(6)])
        time.sleep(5)
        # Randomly fail 10% of the time.
        # We need to reseed here, since the job inherits its parent's random state.
        np.random.seed(int(time.time()))
        if np.random.rand() < 0.1:
            raise ValueError("Randomly failed")
        # The standard error is 0.0, since we are computing a synthetic function.
        return {"hartmann6": (hartmann6(x), 0.0), "l2norm": (np.sqrt((x**2).sum()), 0.0)}
    # Set up an example experiment
    experiment_params = [
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional; defaults to inference from the type of "bounds".
            "log_scale": False,  # Optional; defaults to False.
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
    ]
    json_file_path = os.path.join('experiments', f'{experiment_name}.json')
    objectives = {"hartmann6": ObjectiveProperties(minimize=True)}
    ax_client = init_or_load_ax_client(json_file_path, experiment_name, experiment_params, objectives)

    # Asynchronous experiment
    executor = submitit.AutoExecutor(folder='jobs')
    # executor = submitit.DebugExecutor(folder='jobs')
    executor.update_parameters(
        timeout_min=119,
        slurm_account=os.getenv('SLURM_ACCOUNT'),
        slurm_array_parallelism=20,
        mem_gb=32,
        cpus_per_task=8,
        gpus_per_node=1,
        name=experiment_name,
    )
    asyncio.run(run_trials(model_trial, executor, ax_client, trial_budget=13))
    # Did that work?
    best_parameters, values = ax_client.get_best_parameters()
    print(f"Best parameters: {best_parameters}, values: {values}")
    best_objectives = []
    for trial in ax_client.experiment.trials.values():
        if trial.status == TrialStatus.COMPLETED:
            best_objectives.append(trial.objective_mean)
            print(trial.objective_mean)
        else:
            print(trial.status)
    # Wrap in an extra list so np.minimum.accumulate sees a 2D array
    best_objectives = np.array([best_objectives])
    plt.figure()
    plt.plot(np.minimum.accumulate(best_objectives, axis=1).T, marker="o")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        plt.savefig(tmpfile.name)
    # Use imgcat to display the image in iTerm2
    # (beware: shell=True is insecure on systems we do not control)
    subprocess.run(f"imgcat {tmpfile.name}", shell=True)