accelerator.print("Loading the model from user-defined `load_model`")
model = load_model(id2label, hyperparams['model_name'])
optimizer = torch.optim.AdamW(model.parameters(), lr=hyperparams['learning_rate'])
# A linear warmup reduces the primacy effect of early training examples. Read more:
# https://stackoverflow.com/a/55942518/2415539
lr_scheduler = get_linear_schedule_with_warmup(
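For reference, `get_linear_schedule_with_warmup` from Hugging Face `transformers` scales the optimizer's base learning rate by a factor that rises linearly from 0 to 1 over `num_warmup_steps` and then falls linearly to 0 at `num_training_steps`. The pure-Python function below is a sketch of that multiplier, not the library code. The step counts in the usage lines are illustrative; the 100 warmup steps match `CONSTANT_HP_WARMUP_STEPS`.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier for linear warmup followed by linear decay to 0."""
    if step < num_warmup_steps:
        # Warmup: ramp from 0 towards 1 so early batches take small steps
        return step / max(1, num_warmup_steps)
    # Decay: fall from 1 at the end of warmup to 0 at num_training_steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# Illustrative values: 100 warmup steps out of 1,100 optimizer steps
for step in (0, 50, 100, 600, 1100):
    print(step, linear_warmup_multiplier(step, 100, 1100))
```

The multiplier peaks at exactly 1.0 when warmup ends, so the configured `learning_rate` is the maximum rate the run ever uses.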

encoded_data_path = os.path.join(SAGEMAKER_LOCAL_TRAINING_DIR, 'encoded_data')
accelerator.print(f"Loading dataset from {encoded_data_path}")
dataset = load_from_disk(encoded_data_path)
train_data = dataset['train']
valid_data = dataset['valid']
train_data_len = train_data.num_rows
valid_data_len = valid_data.num_rows
# Return PyTorch tensors when rows are indexed
train_data.set_format(type="torch")
valid_data.set_format(type="torch")
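The snippet records `train_data_len` without using it here. One common use, assumed rather than shown in the original, is sizing the learning-rate schedule. A pattern seen in Accelerate's example scripts is `num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps`. The hypothetical helper below derives the same number from the row count. It assumes a `DataLoader` that keeps the last partial batch.

```python
import math


def compute_num_training_steps(train_data_len, batch_size, gradient_accumulation_steps, num_epochs):
    """Hypothetical helper: optimizer steps for the whole run."""
    # Micro-batches per epoch; the last partial batch is kept (drop_last=False)
    batches_per_epoch = math.ceil(train_data_len / batch_size)
    # One optimizer (and scheduler) step per accumulation window
    return (batches_per_epoch * num_epochs) // gradient_accumulation_steps


# 1,000 rows, micro-batches of 4, 16 accumulation steps, 5 epochs
print(compute_num_training_steps(1000, 4, 16, 5))  # 78
```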

# Accelerator lets the same script train on a single CPU, multiple CPUs, a single GPU, or multiple GPUs
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    log_with='wandb'
)
accelerator.print("Accelerator has determined the number of processes to be:", accelerator.num_processes)
# Set the wandb API key on the main process only
if accelerator.is_main_process:
    os.environ['WANDB_API_KEY'] = config.wandb_api_key

def get_gradient_accum_batch_size(hyperparams):
    # Split the requested batch into micro-batches that fit on one GPU and
    # accumulate gradients over them. If train_batch_size is not a multiple of
    # max_gpu_batch_size, the effective batch size is rounded down.
    batch_size = hyperparams['train_batch_size']
    gradient_accumulation_steps = 1
    if batch_size > hyperparams['max_gpu_batch_size']:
        gradient_accumulation_steps = batch_size // hyperparams['max_gpu_batch_size']
        batch_size = hyperparams['max_gpu_batch_size']
    return gradient_accumulation_steps, batch_size
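Gradient accumulation works because, when the loss is averaged over the batch, the mean of the gradients of equal-sized micro-batches equals the gradient of the full batch. The toy sketch below checks this in plain Python for a one-parameter least-squares model. The model and data are illustrative only. With `train_batch_size=64` and `max_gpu_batch_size=4`, `get_gradient_accum_batch_size` returns `(16, 4)`, and that is the split used here.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean squared error for the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the gradients of consecutive micro-batches, as accumulation does."""
    grads = [
        mse_grad(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
        for i in range(0, len(xs), micro_batch_size)
    ]
    return sum(grads) / len(grads)


xs = [float(i) for i in range(64)]
ys = [3.0 * x + 1.0 for x in xs]
full = mse_grad(0.5, xs, ys)              # one batch of 64
accum = accumulated_grad(0.5, xs, ys, 4)  # 16 micro-batches of 4
print(full, accum, abs(full - accum) < 1e-6)
```

The equality relies on equal micro-batch sizes. With an uneven final micro-batch, a plain average of the micro-batch gradients would weight its samples more heavily.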

hyperparams = environment.read_hyperparameters()
# SageMaker tuning jobs add `_tuning_objective_metric` to the hyperparameters
is_tuning_job = '_tuning_objective_metric' in hyperparams
run_name = f"run_{config.run_num}"
if is_tuning_job:
    job_type_str = 'tuning'
elif config.is_comparison:
    job_type_str = f"{hyperparams['model_name']}_training"
else:
    job_type_str = 'training'
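Moving the branch into a function makes the job-type logic testable without SageMaker or the project's `config` module. `get_job_type` is a hypothetical name, and `is_comparison` stands in for `config.is_comparison`. The `_tuning_objective_metric` key is the same one the snippet above already checks for.

```python
def get_job_type(hyperparams, is_comparison):
    """Hypothetical refactor of the job-type branch above."""
    # Tuning jobs are detected by the `_tuning_objective_metric` hyperparameter
    if '_tuning_objective_metric' in hyperparams:
        return 'tuning'
    # Comparison runs are labelled per model so they can be told apart
    if is_comparison:
        return f"{hyperparams['model_name']}_training"
    return 'training'


print(get_job_type({'_tuning_objective_metric': 'objective_metric_f1'}, False))  # tuning
print(get_job_type({'model_name': 'roberta-base'}, True))   # roberta-base_training
print(get_job_type({'model_name': 'roberta-base'}, False))  # training
```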

METRIC_NAME=objective_metric_f1
METRIC_REGEX=.*objective_metric_f1=(.*?);
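SageMaker finds the objective by applying `METRIC_REGEX` to the training job's log output, so the training script must print the F1 score in a matching format. The sketch below assumes a log line ending in `objective_metric_f1=0.8731;`, where the rest of the line is hypothetical, and checks that the regex captures the number. The trailing `;` is required because the lazy group `(.*?)` stops at the first semicolon.

```python
import re

# Same pattern as METRIC_REGEX above
METRIC_REGEX = r".*objective_metric_f1=(.*?);"


def format_metric_line(epoch, f1):
    """Hypothetical log line the training script could print after each epoch."""
    return f"epoch={epoch} objective_metric_f1={f1:.4f};"


def extract_metric(line):
    """Return the captured F1 value, or None if the line does not match."""
    match = re.search(METRIC_REGEX, line)
    return float(match.group(1)) if match else None


line = format_metric_line(3, 0.87314)
print(line)                  # epoch=3 objective_metric_f1=0.8731;
print(extract_metric(line))  # 0.8731
```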

metric_definitions = [{"Name": config.metric_name, "Regex": config.metric_regex}]
objective_metric_name = config.metric_name
tuner = HyperparameterTuner(
    estimator,
    objective_metric_name,
    tunable_hyperparams,
    metric_definitions,
    objective_type=config.objective_type,
    max_jobs=config.max_jobs,
    max_parallel_jobs=config.max_parallel_jobs

tunable_hyperparams = {}
for key, tunable_param in config.tunable_hyperparams.items():
    if isinstance(tunable_param, tuple):
        # A (min, max) tuple is a range: an int lower bound gives an integer range, otherwise a continuous one
        if isinstance(tunable_param[0], int):
            tunable_hyperparams[key] = IntegerParameter(tunable_param[0], tunable_param[1])
        else:
            tunable_hyperparams[key] = ContinuousParameter(tunable_param[0], tunable_param[1])
    else:
        # Otherwise, it's categorical
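The loop above picks a SageMaker range class by inspecting each spec. The dependency-free sketch below mirrors that dispatch but returns the class name instead of building the object, so it runs without the `sagemaker` package. It makes two assumptions. First, the truncated categorical branch presumably wraps the values in `CategoricalParameter`. Second, the original checks only `tunable_param[0]`, while this version requires both bounds to be ints before choosing an integer range, so a spec like `(5, 30.5)` becomes a continuous range.

```python
def range_class_name(tunable_param):
    """Name of the sagemaker.tuner range class a spec would map to (sketch)."""
    if isinstance(tunable_param, tuple):
        low, high = tunable_param
        # Integer range only when both bounds are ints
        if isinstance(low, int) and isinstance(high, int):
            return 'IntegerParameter'
        return 'ContinuousParameter'
    # Any non-tuple (e.g. a list of choices) is treated as categorical
    return 'CategoricalParameter'


specs = {'epochs': (5, 30), 'learning_rate': (0.00001, 0.01), 'train_batch_size': [8, 16, 32, 64]}
for key, spec in specs.items():
    print(key, range_class_name(spec))
```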

# Constant Hyperparams
CONSTANT_HP_EVAL_BATCH_SIZE=4
CONSTANT_HP_MAX_GPU_BATCH_SIZE=4
CONSTANT_HP_SEED=1337
CONSTANT_HP_WARMUP_STEPS=100
# Tunable Hyperparams
TUNABLE_HP_EPOCHS=5 -> 30
TUNABLE_HP_TRAIN_BATCH_SIZE=8, 16, 32, 64
TUNABLE_HP_LEARNING_RATE=.00001 -> .01
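The `TUNABLE_HP_*` lines follow a small convention: `a -> b` is a range and `a, b, c` is a list of choices. This matches the tuple/list distinction that the tuning loop checks. The original does not show how the project parses these values, so the parser below is a hypothetical sketch. It assumes keys are lower-cased after stripping the prefix, which yields `learning_rate` and `train_batch_size`, the names the training script reads from `hyperparams`.

```python
def parse_number(text):
    """Parse '5' as int and '.00001' as float."""
    text = text.strip()
    try:
        return int(text)
    except ValueError:
        return float(text)


def parse_tunable_value(raw):
    """'5 -> 30' becomes a (min, max) tuple; '8, 16, 32' becomes a list of choices."""
    if '->' in raw:
        low, high = raw.split('->')
        return (parse_number(low), parse_number(high))
    # Keep categorical choices as strings; SageMaker passes hyperparameters as strings
    return [item.strip() for item in raw.split(',')]


def parse_tunable_env(env, prefix='TUNABLE_HP_'):
    """Collect TUNABLE_HP_* entries into {hyperparameter_name: spec}."""
    return {
        key[len(prefix):].lower(): parse_tunable_value(value)
        for key, value in env.items()
        if key.startswith(prefix)
    }


env = {
    'CONSTANT_HP_SEED': '1337',
    'TUNABLE_HP_EPOCHS': '5 -> 30',
    'TUNABLE_HP_TRAIN_BATCH_SIZE': '8, 16, 32, 64',
    'TUNABLE_HP_LEARNING_RATE': '.00001 -> .01',
}
print(parse_tunable_env(env))
```

Integer bounds stay `int` and decimal bounds become `float`. This is the distinction the tuning loop uses to choose between `IntegerParameter` and `ContinuousParameter`.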

role = get_role(config.execution_role)
estimator_kwargs = get_estimator_kwargs(config, role, config.constant_hyperparams)
if config.use_distrbuted:
    estimator_kwargs = add_distributed_config(config, estimator_kwargs, config.constant_hyperparams)
estimator = Estimator(**estimator_kwargs)