import logging
from finetune_trainer import DataTrainingArguments, Seq2SeqTrainingArguments
from seq2seq_trainer import Seq2SeqTrainer
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from utils import Seq2SeqDataCollator, Seq2SeqDataset, build_compute_metrics_fn, freeze_embeds, freeze_params
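# Note: finetune_trainer, seq2seq_trainer and utils are the helper modules from the
# transformers examples/seq2seq directory (examples/legacy/seq2seq in newer releases),
# so this script is assumed to live in, and be run from, that directory.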
# Setup logging
logging.basicConfig(level=logging.INFO)
data_args = DataTrainingArguments(
    data_dir="wmt_en_ro",
    task="translation",
    max_source_length=128,
    max_target_length=128,
    val_max_target_length=128,
    n_train=-1,
    n_val=-1,
    eval_beams=2,
)
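# n_train / n_val of -1 means "use the full split"; eval_beams=2 keeps beam-search
# generation cheap during the frequent evaluations configured below.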
# Evaluate during training and a bit more often than the default to be able to prune bad trials early.
# Disabling tqdm is a matter of preference.
training_args = Seq2SeqTrainingArguments(
    "test",
    evaluate_during_training=True,
    predict_with_generate=True,
    eval_steps=500,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    fp16=True,
    label_smoothing=0.1,
    dropout=0.0,
    disable_tqdm=True,
)
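# Note: fp16=True assumes a CUDA GPU with mixed-precision support on the transformers
# version these scripts target; drop it when running on CPU.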
tokenizer = AutoTokenizer.from_pretrained("sshleifer/student_marian_en_ro_6_3")
config = AutoConfig.from_pretrained("sshleifer/student_marian_en_ro_6_3")
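# The config loaded here is only handed to Seq2SeqTrainer below (pad token id, vocab size,
# etc.); each trial rebuilds its own config inside model_init.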
# Get datasets
train_dataset = Seq2SeqDataset(
    tokenizer,
    type_path="train",
    data_dir=data_args.data_dir,
    n_obs=data_args.n_train,
    max_target_length=data_args.max_target_length,
    max_source_length=data_args.max_source_length,
)
eval_dataset = Seq2SeqDataset(
    tokenizer,
    type_path="val",
    data_dir=data_args.data_dir,
    n_obs=data_args.n_val,
    max_target_length=data_args.val_max_target_length,
    max_source_length=data_args.max_source_length,
)
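# model_init is called by the Trainer at the start of every trial, so model-level
# hyperparameters such as dropout are suggested here rather than in hp_space.
# Freezing the embeddings and the encoder leaves only the decoder weights trainable,
# which speeds up each trial.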
def model_init(trial):
    if trial is not None:  # suggest per-trial config params here
        dropout = trial.suggest_float("dropout", 0, 0.4)
    else:
        dropout = 0.0
    config = AutoConfig.from_pretrained("sshleifer/student_marian_en_ro_6_3", dropout=dropout)
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/student_marian_en_ro_6_3", config=config)
    freeze_embeds(model)
    freeze_params(model.get_encoder())
    return model
compute_metrics_fn = build_compute_metrics_fn(data_args.task, tokenizer)
trainer = Seq2SeqTrainer(
    model_init=model_init,
    config=config,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
    compute_metrics=compute_metrics_fn,
    data_args=data_args,
)
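# hp_space defines the per-trial training hyperparameters. trainer.hyperparameter_search
# uses the Optuna backend by default when optuna is installed (pip install optuna), which
# is what the trial.suggest_* calls below assume.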
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 0.5, log=True),
        "label_smoothing": trial.suggest_float("label_smoothing", 0, 0.4),
        "gradient_accumulation_steps": trial.suggest_categorical("gradient_accumulation_steps", [1, 8, 32, 128, 256]),
    }
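# compute_objective receives the metrics dict returned by evaluate(); "eval_bleu" is the
# BLEU score produced by build_compute_metrics_fn for the translation task.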
def objective(metrics):
    return metrics["eval_bleu"]
def get_timeout(hours=24):
    return hours * 3600
# run for 10k trials or 12 hours, whichever comes first
trainer.hyperparameter_search(
    direction="maximize", hp_space=hp_space, compute_objective=objective, n_trials=10000, timeout=get_timeout(hours=12)
)
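# hyperparameter_search returns a BestRun (run_id, objective, hyperparameters); assign the
# call above if you want to inspect the winning trial, e.g. (hypothetical variable name):
#   best_run = trainer.hyperparameter_search(...)
#   print(best_run.objective, best_run.hyperparameters)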