Created
February 12, 2022 22:18
-
-
Save tezansahu/eabc97ba1d6ec237a57dca479b9a8b8e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training configuration for the multimodal (text + image) VQA model.
# Evaluation, logging and checkpointing are all aligned on the same
# 100-step cadence so that "best model" selection has a metric at every
# checkpoint (required by load_best_model_at_end).
multi_args = TrainingArguments(
    output_dir="checkpoint",
    seed=12345,                      # fixed seed for reproducible runs
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,              # models are large: keep only the last 3 checkpoints
    metric_for_best_model="wups",    # WUPS score computed by compute_metrics
    greater_is_better=True,          # explicit: higher WUPS is better
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    remove_unused_columns=False,     # keep raw columns the custom collator needs
    num_train_epochs=5,
    fp16=True,                       # mixed precision; assumes a CUDA GPU — TODO confirm
    dataloader_num_workers=8,
    load_best_model_at_end=True,     # reload best WUPS checkpoint after training
)

# Build the data collator and the multimodal model from a text encoder
# (BERT) and an image encoder (ViT pretrained on ImageNet-21k).
collator, model = createMultimodalVQACollatorAndModel(
    "bert-base-uncased", "google/vit-base-patch16-224-in21k"
)

# Wire dataset, collator, model, hyperparameters and metrics into the Trainer.
# NOTE(review): `dataset` and `compute_metrics` are defined elsewhere in the
# project — assumed to be a DatasetDict with "train"/"test" splits and a
# metrics callback returning a dict containing "wups"; verify against caller.
multi_trainer = Trainer(
    model,
    multi_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

# Run the training loop (checkpointing/evaluating every 100 steps).
train_multi_metrics = multi_trainer.train()

# Final evaluation on the held-out split with the best checkpoint loaded.
eval_multi_metrics = multi_trainer.evaluate()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment