@LucaHermes
Last active June 8, 2022 08:26
Script to run wandb sweeps with a number of different seeds: Given some seeds, this sweep controller samples a model configuration and starts one run per seed with the exact same setup. I made this for a model that had a high variance to get somewhat reliable values for every configuration.
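The core idea, one sampled configuration fanned out into one run per seed, can be sketched without wandb. The helper name `expand_config_over_seeds` and the plain-dict config are assumptions made for illustration; only the `{'value': ...}` entry layout and the `seed` / `sweep_step_id` keys come from the script itself.

```python
import copy
import uuid


def expand_config_over_seeds(config, seeds):
    """Return one copy of `config` per seed, all sharing one step ID.

    Hypothetical helper: `config` is a plain dict whose entries use the
    {'value': ...} layout that the script writes into `params.config`.
    """
    # One ID per sampled configuration, shared by all of its seeds.
    step_id = str(uuid.uuid1())
    run_configs = []
    for seed in seeds:
        # Copy so that each run owns its config and the input is not mutated.
        run_config = copy.deepcopy(config)
        run_config['sweep_step_id'] = {'value': step_id}
        run_config['seed'] = {'value': int(seed)}
        run_configs.append(run_config)
    return run_configs


if __name__ == '__main__':
    sampled = {'learning_rate': {'value': 0.01}}
    for cfg in expand_config_over_seeds(sampled, [11, 22, 33]):
        print(cfg)
```

Every returned config differs only in `seed`, so filtering or grouping by `sweep_step_id` collects exactly the runs that belong to one sampled configuration.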
'''
Wandb Sweep: Multiple seeds same config/parameters
Script to run wandb sweeps with a number of different seeds:
Given some seeds, this sweep controller samples a model configuration and starts one run per seed with the exact same setup.
I made this for a model that had a high variance to get somewhat reliable values for every configuration.
To group the runs with the different seeds in the wandb webapp, an additional value 'sweep_step_id' is passed.
It is identical for all seeds of one configuration and differs between configurations.
To execute, run
    $ python run_sweep_with_seeds.py
This starts a local sweep controller. Agents that you start afterwards are served by this controller script.
# Example sweep.yaml
program: train.py
controller:
  type: local
method: random
metric:
  name: accuracy
  goal: maximize
parameters:
  learning_rate:
    min: 0.0001
    max: 0.1
'''
import numpy as np
import wandb
import yaml
import time
import uuid
# number of seeds to run per configuration
N_SEEDS = 10
# sweep setup file
SWEEP_CONFIG_FILE = 'sweep.yaml'
PROJECT = 'project_name'
# it is not necessary to set this seed, but I wanted my script to sample consistent seeds
np.random.seed(42)
seeds = np.random.randint(0, np.iinfo(np.int32).max, size=N_SEEDS)
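with open(SWEEP_CONFIG_FILE, 'r') as f:
    # yaml.load() needs an explicit Loader in recent PyYAML versions;
    # safe_load is sufficient for a plain sweep config
    sweep_config = yaml.safe_load(f)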
# initialize a new sweep
sweep_id = wandb.sweep(sweep_config, project=PROJECT)
# start controller
sweep = wandb.controller(sweep_id)
n_scheduled = 0
current_runs = 0
while not sweep.done():
    # one ID per sampled configuration, shared by all of its seeds
    sweep_step_id = str(uuid.uuid1())
    # sample the next configuration from the sweep
    params = sweep.search()
    for seed in seeds:
        params.config['sweep_step_id'] = { 'value' : sweep_step_id }
        params.config['seed'] = { 'value' : int(seed) }
        n_scheduled += 1
        # keep offering this run until an agent has picked it up
        while n_scheduled > current_runs:
            sweep._step()
            sweep.schedule(params)
            runs = sweep.stopping()
            if runs:
                sweep.stop_runs(runs)
            print('waiting for an agent to take next job', '- current runs:',
                  current_runs, '/', n_scheduled, '\t', sweep._laststatus, ' - is done:', sweep.done())
            time.sleep(1.)
            current_runs = len(sweep._sweep_runs)
        sweep.print_status()
        print('agent is ready for next job\n scheduling agent')
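On the training side, the program named in sweep.yaml (`train.py`) has to read the injected `seed` and `sweep_step_id`. The gist does not include that file, so the following is only a sketch under assumptions: `setup_run` and `run_once` are hypothetical names, a plain dict stands in for the run config, and Python's `random` module stands in for whatever RNGs the real model uses.

```python
import random


def setup_run(config):
    """Seed the RNG from the config and return (seed, sweep_step_id).

    Hypothetical helper: `config` is any mapping that holds the 'seed' and
    'sweep_step_id' values added by the sweep controller.
    """
    seed = int(config['seed'])
    # A real training script would also seed numpy / its ML framework here.
    random.seed(seed)
    return seed, config['sweep_step_id']


def run_once(config):
    """Stand-in for training: returns a seed-dependent fake accuracy."""
    seed, step_id = setup_run(config)
    return {'sweep_step_id': step_id, 'seed': seed, 'accuracy': random.random()}


if __name__ == '__main__':
    # Two seeds of the same configuration share one sweep_step_id.
    for s in (1, 2):
        print(run_once({'seed': s, 'sweep_step_id': 'example-step'}))
```

Seeding from the config makes each run reproducible for its seed, while the shared `sweep_step_id` lets the per-seed results of one configuration be averaged afterwards.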