Skip to content

Instantly share code, notes, and snippets.

@lebrice
Last active June 9, 2022 19:18
Show Gist options
  • Save lebrice/aadab906de6f7dd8a702f5e6e77bf536 to your computer and use it in GitHub Desktop.
Save lebrice/aadab906de6f7dd8a702f5e6e77bf536 to your computer and use it in GitHub Desktop.
Job array example
from dataclasses import dataclass
import os
from simple_parsing import ArgumentParser
from itertools import product
@dataclass
class ProblemConfig:
    """Description of one problem instance to train on.

    Each field is exposed as a command-line option by ``simple_parsing``; the
    comment above each field is used as that option's help text.
    """

    # Which dataset ID to use.
    dataset: int = 0
    # The rank of some matrix
    rank: int = 0
    # The size of the alphabet.
    alphabet_size: int = 0


@dataclass
class Hyperparameters:
    """Training hyper-parameters for a single run.

    ``main`` uses one of these per array task as the *default* value; every
    field can still be overridden from the command line.
    """

    # Optimizer step size.
    learning_rate: float = 1e-4
    # Name of the weight-initialization scheme to use.
    initialization: str = "xavier_uniform"
    # Random seed for this run.
    seed: int = 42


# --- Problem configurations -------------------------------------------------
# The three lists below are combined element-wise with `zip`, so problem i uses
# datasets[i], ranks[i] and alphabet_size[i]: 48 problems in total, NOT the
# full 48 * 48 * 48 Cartesian product.
datasets = list(range(48))
ranks = list(range(48))
alphabet_size = list(range(48))
problem_configurations = [
    ProblemConfig(dataset=ds, rank=rk, alphabet_size=size)
    for ds, rk, size in zip(datasets, ranks, alphabet_size)
]

# --- Hyper-parameter configurations -----------------------------------------
# Every combination of the values below (Cartesian product): 4 * 3 * 10 = 120
# configurations. `product` varies its last argument fastest, so consecutive
# entries differ by seed first, then by initialization, then by learning rate.
learning_rates = [1e-4, 1e-3, 1e-2, 1e-1]
initializations = ["xavier_uniform", "xavier_normal", "uniform"]
seeds = list(range(10))
hyper_parameter_configurations = [
    Hyperparameters(learning_rate=lr, initialization=init, seed=sd)
    for lr, init, sd in product(learning_rates, initializations, seeds)
]


def main():
    """Parse this array task's configuration from the command line and train.

    The task index is read from the ``SLURM_ARRAY_TASK_ID`` environment
    variable (0 when it is unset, e.g. when running outside a job array). It
    selects an entry of ``hyper_parameter_configurations``. That entry is only
    used as the *default* hyper-parameters: any field can still be overridden
    on the command line.

    Raises:
        IndexError: if the task index does not refer to an existing entry of
            ``hyper_parameter_configurations`` (for instance when the job array
            was submitted with more tasks than there are configurations).
        ValueError: if ``SLURM_ARRAY_TASK_ID`` is set but is not an integer.
    """
    parser = ArgumentParser()
    # Set per task by SLURM for job arrays; default to 0 for interactive runs.
    job_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))

    # Validate the index up front. Without this, an oversized array fails with a
    # bare IndexError, and a negative ID would silently wrap around to an entry
    # from the end of the list. If you index `problem_configurations` instead,
    # check against that list's length.
    num_configs = len(hyper_parameter_configurations)
    if not 0 <= job_id < num_configs:
        raise IndexError(
            f"SLURM_ARRAY_TASK_ID={job_id} is out of range: there are "
            f"{num_configs} hyper-parameter configurations, so valid task IDs "
            f"are 0 to {num_configs - 1} (e.g. submit with "
            f"--array=0-{num_configs - 1})."
        )

    # Depending on how you create your array jobs, choose either:
    hparams_defaults = hyper_parameter_configurations[job_id]
    # problem_for_this_job = problem_configurations[job_id]

    parser.add_arguments(ProblemConfig, "problem")
    # NOTE: This only changes the default values; the hparams can still be
    # overwritten from the command line.
    parser.add_arguments(Hyperparameters, "hparams", default=hparams_defaults)
    args = parser.parse_args()

    problem: ProblemConfig = args.problem
    hparams: Hyperparameters = args.hparams
    train(problem=problem, hparams=hparams)


def train(problem: ProblemConfig, hparams: Hyperparameters):
    """Train on ``problem`` using ``hparams``.

    For now this only prints both configurations; replace the ``...`` below
    with the real training code. To run with specific hyper-parameters without
    going through the command-line parser, call this function directly, e.g.
    ``train(problem=some_problem, hparams=<whatever>)``.
    """
    for label, config in (("problem", problem), ("hparams", hparams)):
        print(f"{label}: {config}")
    ...  # your training code.


# Script entry point: only parse arguments and train when executed directly,
# not when this module is imported (e.g. to reuse the config dataclasses).
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment