import logging

from finetune_trainer import DataTrainingArguments, Seq2SeqTrainingArguments
from seq2seq_trainer import Seq2SeqTrainer
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from utils import Seq2SeqDataCollator, Seq2SeqDataset, build_compute_metrics_fn, freeze_embeds, freeze_params

# Setup logging
logging.basicConfig(level=logging.INFO)
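# Data arguments for the legacy examples/seq2seq finetuning scripts (where
# DataTrainingArguments is defined): n_train / n_val of -1 means "use the
# full split", and eval_beams=2 keeps beam search during evaluation cheap.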
data_args = DataTrainingArguments(
    data_dir="wmt_en_ro",
    task="translation",
    max_source_length=128,
    max_target_length=128,
    val_max_target_length=128,
    n_train=-1,
    n_val=-1,
    eval_beams=2,
)
# Evaluate during training, and more often than the default, so that bad
# trials can be pruned early. Disabling tqdm is a matter of preference.
training_args = Seq2SeqTrainingArguments(
    "test",
    evaluate_during_training=True,
    predict_with_generate=True,
    eval_steps=500,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    fp16=True,
    label_smoothing=0.1,
    dropout=0.0,
    disable_tqdm=True,
)
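# label_smoothing and dropout above are only baseline values: during the
# search, keys returned by hp_space (e.g. label_smoothing) are set on these
# args per trial, while dropout is re-suggested inside model_init below.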
tokenizer = AutoTokenizer.from_pretrained("sshleifer/student_marian_en_ro_6_3")
config = AutoConfig.from_pretrained("sshleifer/student_marian_en_ro_6_3")

# Get datasets
train_dataset = Seq2SeqDataset(
    tokenizer,
    type_path="train",
    data_dir=data_args.data_dir,
    n_obs=data_args.n_train,
    max_target_length=data_args.max_target_length,
    max_source_length=data_args.max_source_length,
)
eval_dataset = Seq2SeqDataset(
    tokenizer,
    type_path="val",
    data_dir=data_args.data_dir,
    n_obs=data_args.n_val,
    max_target_length=data_args.val_max_target_length,
    max_source_length=data_args.max_source_length,
)
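# In these utilities, Seq2SeqDataset reads line-aligned {type_path}.source /
# {type_path}.target files from data_dir (e.g. wmt_en_ro/train.source).

# hyperparameter_search calls model_init once per trial, so every trial starts
# from a fresh model built with that trial's suggested dropout. Freezing the
# embeddings and the encoder shrinks the trainable parameter count.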
def model_init(trial):
    if trial is not None:  # suggest config params here
        dropout = trial.suggest_float("dropout", 0, 0.4)
    else:
        dropout = 0.0
    config = AutoConfig.from_pretrained("sshleifer/student_marian_en_ro_6_3", dropout=dropout)
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/student_marian_en_ro_6_3", config=config)
    freeze_embeds(model)
    freeze_params(model.get_encoder())
    return model
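# For the translation task, build_compute_metrics_fn computes BLEU; it shows
# up in the eval metrics dict as "eval_bleu" (used by `objective` below).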
compute_metrics_fn = build_compute_metrics_fn(data_args.task, tokenizer)

trainer = Seq2SeqTrainer(
    model_init=model_init,
    config=config,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
    compute_metrics=compute_metrics_fn,
    data_args=data_args,
)
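# The search space: each suggest_* call registers a hyperparameter with the
# Optuna trial. log=True samples learning_rate on a log scale, the usual
# choice when the plausible range spans several orders of magnitude.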
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 0.5, log=True),
        "label_smoothing": trial.suggest_float("label_smoothing", 0, 0.4),
        "gradient_accumulation_steps": trial.suggest_categorical("gradient_accumulation_steps", [1, 8, 32, 128, 256]),
    }
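# compute_objective maps the eval metrics dict to the scalar that Optuna
# maximizes; here that is BLEU on the validation set.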
def objective(metrics):
    return metrics["eval_bleu"]

def get_timeout(hours=24):
    return hours * 3600
# Run for 10k trials or 12 hours, whichever comes first.
trainer.hyperparameter_search(
    direction="maximize", hp_space=hp_space, compute_objective=objective, n_trials=10000, timeout=get_timeout(hours=12)
)
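# hyperparameter_search returns a BestRun; a minimal sketch of inspecting it,
# assuming the call above is instead bound to a variable:
#
#   best_run = trainer.hyperparameter_search(...)
#   print(best_run.run_id, best_run.objective, best_run.hyperparameters)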