Last active
June 8, 2022 08:26
Script to run wandb sweeps with a number of different seeds: Given some seeds, this sweep controller samples a model configuration and starts one run per seed with the exact same setup. I made this for a model that had a high variance to get somewhat reliable values for every configuration.
'''
Wandb Sweep: Multiple seeds same config/parameters

Script to run wandb sweeps with a number of different seeds:
Given some seeds, this sweep controller samples a model configuration and starts one run per seed with the exact same setup.
I made this for a model that had a high variance to get somewhat reliable values for every configuration.
In order to group the runs with the different seeds in the wandb webapp, an additional value 'sweep_step_id' is passed
that is the same for all runs (seeds) of one sampled configuration.

To execute, run
$ python run_sweep_with_seeds.py
A local sweep controller will be started. You can run agents afterwards; they get served by this controller script.

# Example sweep.yaml
program: train.py
controller:
  type: local
method: random
metric:
  name: accuracy
  goal: maximize
parameters:
  learning_rate:
    min: 0.0001
    max: 0.1
'''
import numpy as np
import wandb
import yaml
import time
import uuid

# number of seeds to run per configuration
N_SEEDS = 10
# sweep setup file
SWEEP_CONFIG_FILE = 'sweep.yaml'
PROJECT = 'project_name'

# it is not necessary to set this seed, but I wanted my script to sample consistent seeds
np.random.seed(42)
seeds = np.random.randint(0, np.iinfo(np.int32).max, size=N_SEEDS)

with open(SWEEP_CONFIG_FILE, 'r') as f:
    # safe_load: yaml.load without an explicit Loader is rejected by recent PyYAML versions
    sweep_config = yaml.safe_load(f)

# initialize a new sweep
sweep_id = wandb.sweep(sweep_config, project=PROJECT)
# start controller
sweep = wandb.controller(sweep_id)

n_scheduled = 0
current_runs = 0

while not sweep.done():
    # one id per sampled configuration, shared by all of its seeds
    sweep_step_id = str(uuid.uuid1())
    params = sweep.search()

    for seed in seeds:
        params.config['sweep_step_id'] = {'value': sweep_step_id}
        params.config['seed'] = {'value': int(seed)}
        n_scheduled += 1

        # keep offering this run until an agent has picked it up
        # (note: _step, _laststatus and _sweep_runs are private controller members)
        while n_scheduled > current_runs:
            sweep._step()
            sweep.schedule(params)
            runs = sweep.stopping()
            if runs:
                sweep.stop_runs(runs)
            print('waiting for an agent to take next job', '- current runs:',
                  current_runs, '/', n_scheduled, '\t', sweep._laststatus, ' - is done:', sweep.done())
            time.sleep(1.)
            current_runs = len(sweep._sweep_runs)

        sweep.print_status()
        print('agent is ready for next job\n scheduling agent')
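
The shared 'sweep_step_id' is what lets the runs of one configuration be grouped afterwards. As a rough illustration of that grouping outside the wandb webapp, the sketch below averages a metric over all seeds of each configuration. It is not part of the script above: the function name summarize_by_step and the run records are hypothetical, and it assumes each run has already been exported as a plain dict holding its 'sweep_step_id' and its final 'accuracy'.

```python
from collections import defaultdict
from statistics import mean, stdev


def summarize_by_step(runs, metric='accuracy'):
    """Group run records by 'sweep_step_id' and summarize `metric` per group."""
    groups = defaultdict(list)
    for run in runs:
        groups[run['sweep_step_id']].append(run[metric])

    summary = {}
    for step_id, values in groups.items():
        summary[step_id] = {
            'n_seeds': len(values),
            'mean': mean(values),
            # the sample standard deviation needs at least two values
            'std': stdev(values) if len(values) > 1 else 0.0,
        }
    return summary


# hypothetical run records; in practice these would come from your own export
runs = [
    {'sweep_step_id': 'a', 'seed': 1, 'accuracy': 0.80},
    {'sweep_step_id': 'a', 'seed': 2, 'accuracy': 0.90},
    {'sweep_step_id': 'b', 'seed': 1, 'accuracy': 0.50},
]
summary = summarize_by_step(runs)
for step_id, stats in summary.items():
    print(step_id, stats)
```

Comparing configurations by their per-group mean (and checking the spread) is the point of running several seeds per configuration; a single run of a high-variance model would not support that comparison.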