#!/usr/bin/env python
"""Ax + SLURM via ``submitit`` and ``asyncio``.

Asynchronous hyperparameter search using [Ax](https://ax.dev/) and the
submitit executor to run trials on SLURM.

* [Ax Service API tutorial](https://ax.dev/tutorials/gpei_hartmann_service.html)
* [submitit/docs/](https://github.com/facebookincubator/submitit/tree/main/docs)

Originally published as GitHub gist
danmackinlay/64806fee0bd2554339a861a5091efe2a (last active April 10, 2024).
"""
import os
import asyncio
import time
import submitit
import cloudpickle
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.core import DataRequiredError
from ax.core.base_trial import TrialStatus
import numpy as np
def is_successful(state):
    """Return True if a submitit job ``state`` string means the job completed successfully.

    The API in submitit seems to have drifted and the list of possible job states
    is large, so this helper accepts every known spelling of "finished OK"
    ('DONE', 'COMPLETED', 'FINISHED') in one place.

    Args:
        state: the job's state string, e.g. ``job.state``.

    Returns:
        bool
    """
    return state in ('DONE', 'COMPLETED', 'FINISHED')
def is_running(state):
    """Return True if a submitit job ``state`` string means the job has not finished yet.

    The API in submitit seems to have drifted and the list of possible job states
    is large, so this helper centralises the "still in progress" check. Besides
    the generic names, it accepts SLURM's transitional (non-terminal) job state
    codes such as 'COMPLETING', 'CONFIGURING', 'REQUEUED' and 'SUSPENDED'.
    Treating those as finished would make a caller give up on a job that is
    about to complete or resume.

    Args:
        state: the job's state string, e.g. ``job.state``.

    Returns:
        bool
    """
    return state in (
        'RUNNING', 'UNKNOWN', 'PENDING', 'STARTED', 'QUEUED',
        # Non-terminal SLURM job state codes.
        'COMPLETING', 'CONFIGURING', 'REQUEUED', 'REQUEUE_FED', 'REQUEUE_HOLD',
        'RESIZING', 'SIGNALING', 'STAGE_OUT', 'SUSPENDED',
    )
def init_or_load_ax_client(json_file_path, experiment_name, parameters, objectives):
    """Return an AxClient restored from ``json_file_path``, or a new experiment.

    If ``json_file_path`` exists and loads, the saved client is returned as-is.
    Note that if you load from disk, ``experiment_name``, ``parameters`` and
    ``objectives`` are ignored: the saved experiment wins. If the file is
    missing or fails to load, a fresh AxClient is created and
    ``create_experiment`` is called with those arguments.

    Args:
        json_file_path: path of a snapshot written by ``AxClient.save_to_json_file``.
        experiment_name: name for a newly created experiment.
        parameters: Ax parameter specifications (list of dicts) for a new experiment.
        objectives: mapping of metric name -> ``ObjectiveProperties`` for a new experiment.

    Returns:
        AxClient
    """
    if os.path.exists(json_file_path):
        try:
            ax_client = AxClient.load_from_json_file(json_file_path)
        except Exception as e:
            # Deliberate fallback: an unreadable snapshot should not block a fresh run.
            print(f"Failed to load AxClient from {json_file_path}: {e}. Initializing a new AxClient.")
        else:
            print(f"Successfully loaded AxClient state from {json_file_path}.")
            return ax_client
    # If the file does not exist or loading failed, create a new AxClient instance
    ax_client = AxClient()
    ax_client.create_experiment(
        name=experiment_name,
        parameters=parameters,
        objectives=objectives,
    )
    print(f"Created a new experiment: {experiment_name}.")
    return ax_client
class JobManager:
    """Track submitit jobs for Ax trials and persist enough state to resume them.

    ``self.jobs`` maps an Ax trial index to a ``(job, parameters)`` tuple. After
    each submission and each completion, the AxClient is written to
    ``ax_save_path`` (JSON) and ``self.jobs`` to ``jobs_state_path``
    (cloudpickle). A restarted driver can then reload both and keep polling the
    same cluster jobs.
    """

    def __init__(self, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=60):
        """
        Args:
            executor: submitit executor used to submit trial functions.
            ax_client: AxClient that receives trial results and failures.
            ax_save_path: file path for the AxClient JSON snapshot.
            jobs_state_path: file path for the pickled jobs dict. Defaults to
                ``ax_save_path + '.jobs.pkl'``. If a directory is given (an
                existing one, or a path ending in a separator), that default
                file name is placed inside it.
            wait_interval: seconds between job-status polls.
        """
        self.executor = executor
        self.ax_client = ax_client
        self.ax_save_path = ax_save_path
        self.jobs_state_path = self._resolve_jobs_state_path(ax_save_path, jobs_state_path)
        self.jobs = {}  # Track jobs by trial_index, storing (job, parameters) tuples
        self.save_lock = asyncio.Lock()
        self.wait_interval = wait_interval

    @staticmethod
    def _resolve_jobs_state_path(ax_save_path, jobs_state_path):
        """Return the pickle file path used for the jobs dict (see ``__init__``)."""
        default_path = ax_save_path + '.jobs.pkl'
        if not jobs_state_path:
            return default_path
        if jobs_state_path.endswith(('/', os.sep)) or os.path.isdir(jobs_state_path):
            # A directory cannot be opened as a file. Without this, every save and
            # load would fail and job persistence would be silently disabled.
            return os.path.join(jobs_state_path, os.path.basename(default_path))
        return jobs_state_path

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    async def safe_save_state(self):
        """Snapshot the AxClient and the jobs dict to disk (best effort).

        The two saves are attempted independently. Failures are printed rather
        than raised, so a disk problem never aborts the polling loop.
        """
        async with self.save_lock:
            # Save AxClient state
            try:
                self._ensure_parent_dir(self.ax_save_path)
                self.ax_client.save_to_json_file(self.ax_save_path)
                print(f"Successfully saved AxClient state to {self.ax_save_path}")
            except Exception as e:
                print(f"Failed to save AxClient state: {e}")
            # Save jobs state. Write to a temp file and rename it into place, so an
            # interruption mid-write cannot leave a truncated pickle behind.
            try:
                self._ensure_parent_dir(self.jobs_state_path)
                tmp_path = self.jobs_state_path + '.tmp'
                with open(tmp_path, 'wb') as f:
                    cloudpickle.dump(self.jobs, f)
                os.replace(tmp_path, self.jobs_state_path)
                print(f"Successfully saved jobs state to {self.jobs_state_path}")
            except Exception as e:
                print(f"Failed to save jobs state: {e}")

    @classmethod
    def load_state(cls, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
        """Build a JobManager, restoring ``jobs`` from its pickle file when one exists.

        If the file is missing or cannot be unpickled, the manager starts with no
        jobs. Only point this at pickles written by this script: unpickling can
        execute arbitrary code.

        Returns:
            JobManager
        """
        job_manager = cls(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
        path = job_manager.jobs_state_path
        if os.path.exists(path):
            try:
                with open(path, 'rb') as f:
                    jobs = cloudpickle.load(f)
                print(f"Successfully loaded jobs state from {path}")
            except Exception as e:
                print(f"Failed to load jobs state: {e}")
            else:
                job_manager.jobs = jobs
        return job_manager

    def _report_failure(self, trial_index, error):
        """Mark ``trial_index`` as failed in Ax, recording ``str(error)`` as its stderr."""
        job_stderr = str(error)
        self.ax_client.log_trial_failure(trial_index=trial_index, metadata={"stderr": job_stderr})
        print(f"Trial {trial_index} failed with error:\n{job_stderr}")

    async def process_job(self, fn, parameters, trial_index, is_new=True):
        """Submit (or resume) the job for one trial and report its outcome to Ax.

        Args:
            fn: callable submitted as ``fn(parameters)``; only used when ``is_new``.
            parameters: Ax parameter dict for the trial.
            trial_index: Ax trial index that the job evaluates.
            is_new: if True, submit a new job. Otherwise resume the job already
                stored in ``self.jobs[trial_index]``.
        """
        if is_new:
            try:
                job = self.executor.submit(fn, parameters)
            except Exception as e:
                # e.g. the scheduler rejected the job: fail the trial rather than
                # leaving it RUNNING in Ax forever.
                self._report_failure(trial_index, e)
                await self.safe_save_state()
                return
            self.jobs[trial_index] = (job, parameters)  # Store job and parameters
            # Persist the job handle right away, so a driver crash while the job
            # runs does not lose track of it.
            await self.safe_save_state()
        else:
            job, _ = self.jobs[trial_index]  # Retrieve existing job and parameters
        print(f"Polling job {job.job_id} for trial {trial_index} in state {job.state}")
        try:
            # submitit's AsyncJobProxy.result() polls job.done() every
            # ``poll_frequency`` seconds without blocking the event loop. It then
            # returns job.result(), which raises if the job failed or was cancelled.
            # This avoids depending on the drifting set of state strings.
            result = await job.awaitable().result(poll_frequency=self.wait_interval)
            self.ax_client.complete_trial(trial_index=trial_index, raw_data=result)
        except Exception as e:
            self._report_failure(trial_index, e)
        await self.safe_save_state()

    async def process_all_jobs(self, fn=None):
        """Resume polling, concurrently, every stored job whose Ax trial is still open.

        Trials that Ax already marks terminal (completed/failed/abandoned/early
        stopped) are skipped. Re-reporting them would make Ax reject the result,
        and then reject the failure logged in the error handler as well. Trials
        unknown to the AxClient are skipped too.
        """
        trials = self.ax_client.experiment.trials
        tasks = []
        for trial_index, (_, parameters) in list(self.jobs.items()):
            trial = trials.get(trial_index)
            if trial is None:
                print(f"Skipping job for trial {trial_index}: not present in the Ax experiment")
                continue
            if trial.status.is_terminal:
                continue
            tasks.append(asyncio.create_task(
                self.process_job(fn, parameters, trial_index, is_new=False)))
        await asyncio.gather(*tasks)

    def reattach_incomplete_jobs(self):
        """Print a message for every stored job that has not finished successfully."""
        for trial_index, (job, _parameters) in self.jobs.items():
            if not is_successful(job.state):
                # NOTE(review): the scraped source only retained this message; any
                # re-registration the original did here was lost — confirm intent.
                print(f"Reattached parameters for incomplete job {trial_index}")
async def run_trials(
        fn, executor, ax_client, trial_budget=25,
        job_manager_save_path='experiments/', wait_interval=30,
        ax_save_path=None):
    """Run up to ``trial_budget`` new Ax trials as asynchronous submitit jobs.

    Jobs recorded by a previous run are resumed first. New trials are then
    requested from Ax until ``trial_budget`` have been submitted, backing off for
    ``wait_interval`` seconds whenever Ax cannot generate one yet. Returns only
    after every submitted job has been reported back to Ax.

    Args:
        fn: callable run on the cluster with the trial's parameter dict; its return
            value is passed to Ax as ``raw_data``.
        executor: submitit executor.
        ax_client: AxClient driving the search.
        trial_budget: number of new trials to submit in this call.
        job_manager_save_path: pickle file for the job handles. If it is a
            directory, ``<ax snapshot file name>.jobs.pkl`` inside it is used.
        wait_interval: seconds between job polls and between retries.
        ax_save_path: AxClient JSON snapshot path. Defaults to
            ``experiments/<experiment name>.json``.
    """
    if ax_save_path is None:
        ax_save_path = os.path.join('experiments', f'{ax_client.experiment.name}.json')
    jobs_state_path = job_manager_save_path
    if jobs_state_path and (jobs_state_path.endswith(('/', os.sep)) or os.path.isdir(jobs_state_path)):
        # A directory cannot be opened as a pickle file; keep the jobs file inside it.
        jobs_state_path = os.path.join(jobs_state_path, os.path.basename(ax_save_path) + '.jobs.pkl')

    job_manager = JobManager.load_state(
        executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
    # Process all serialized jobs before starting new ones
    await job_manager.process_all_jobs(fn)

    tasks = []
    trials_submitted = 0
    while trials_submitted < trial_budget:
        try:
            parameters, trial_index = ax_client.get_next_trial()
        except (MaxParallelismReachedException, DataRequiredError) as e:
            print(f"Waiting for jobs to complete due to: {type(e).__name__}")
            await asyncio.sleep(wait_interval)
            continue
        # Keep a strong reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-flight.
        tasks.append(asyncio.create_task(
            job_manager.process_job(fn, parameters, trial_index, is_new=True)))
        trials_submitted += 1
    print(f"Total trials submitted: {trials_submitted}")

    # asyncio.run() cancels anything still pending when this coroutine returns, so
    # wait for every in-flight job to be reported to Ax. return_exceptions keeps one
    # crashed task from abandoning the others.
    for outcome in await asyncio.gather(*tasks, return_exceptions=True):
        if isinstance(outcome, BaseException):
            print(f"Trial task raised: {outcome!r}")
if __name__ == "__main__":
    # These imports will display an image if we execute this script from iterm2
    import matplotlib.pyplot as plt
    import subprocess
    import tempfile

    # Load some example params, set up a function to use them
    experiment_name = "optim_hartmann6_dev"

    def model_trial(params):
        """Evaluate the synthetic hartmann6 function for one Ax trial.

        A wrapper that takes Ax's parameter dictionary (keys ``x1`` … ``x6``)
        and returns Ax-style ``{metric: (mean, sem)}`` results. It fails at
        random 10% of the time to exercise the failure path.
        """
        x = np.array([params.get(f"x{i+1}") for i in range(6)])
        # Randomly fail 10% of the time. A generator seeded from OS entropy keeps
        # the job from reusing a random state inherited from its parent process.
        if np.random.default_rng().random() < 0.1:
            raise ValueError("Randomly failed")
        # In our case, standard error is 0, since we are computing a synthetic function.
        return {"hartmann6": (hartmann6(x), 0.0), "l2norm": (np.sqrt((x**2).sum()), 0.0)}

    # Set up an example experiment: six float parameters x1..x6 on [0, 1].
    experiment_params = [
        {
            "name": f"x{i}",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        }
        for i in range(1, 7)
    ]
    json_file_path = os.path.join('experiments', f'{experiment_name}.json')
    objectives = {"hartmann6": ObjectiveProperties(minimize=True)}
    ax_client = init_or_load_ax_client(json_file_path, experiment_name, experiment_params, objectives)

    # Asynchronous experiment
    executor = submitit.AutoExecutor(folder='jobs')
    # executor = submitit.DebugExecutor(folder='jobs')
    executor.update_parameters(
        timeout_min=10,  # NOTE(review): original resource settings were lost; adjust for your cluster
        # gpus_per_node=1,
    )
    asyncio.run(run_trials(model_trial, executor, ax_client, trial_budget=13))

    # Did that work?
    best = ax_client.get_best_parameters()
    if best is None:
        print("No completed trials; cannot report best parameters.")
    else:
        best_parameters, values = best
        print(f"Best parameters: {best_parameters}, values: {values}")

    # Objective of each completed trial in trial order (failed trials are skipped),
    # shaped (1, n) so the running minimum is taken along axis 1.
    best_objectives = np.array([[
        trial.objective_mean
        for trial in ax_client.experiment.trials.values()
        if trial.status == TrialStatus.COMPLETED
    ]])
    plt.plot(np.minimum.accumulate(best_objectives, axis=1).T, marker="o")

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        image_path = tmpfile.name
    plt.savefig(image_path)
    # Use imgcat to display the image in iTerm2. An argument list (no shell) avoids
    # shell interpretation of the path; note this still runs whichever `imgcat`
    # is first on PATH.
    try:
        subprocess.run(["imgcat", image_path], check=False)
    except FileNotFoundError:
        print(f"imgcat not found; plot saved to {image_path}")
# Comments on this script can be posted on its GitHub gist page.