@tezansahu
Created February 12, 2022 22:18
# Assumes the HuggingFace Transformers library is installed
from transformers import Trainer, TrainingArguments

# Define the training hyperparameters
multi_args = TrainingArguments(
    output_dir="checkpoint",
    seed=12345,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,              # Since models are large, keep only the 3 most recent checkpoints while training
    metric_for_best_model="wups",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    remove_unused_columns=False,
    num_train_epochs=5,
    fp16=True,                       # Mixed-precision training (requires a CUDA-capable GPU)
    dataloader_num_workers=8,
    load_best_model_at_end=True,
)

# Initialize the actual collator and multimodal (text + image) model
collator, model = createMultimodalVQACollatorAndModel("bert-base-uncased", "google/vit-base-patch16-224-in21k")

# Initialize the trainer with the dataset, collator, model, hyperparameters and evaluation metrics
multi_trainer = Trainer(
    model,
    multi_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

# Start the training loop
train_multi_metrics = multi_trainer.train()

# Run the model on the evaluation set to obtain final metrics
eval_multi_metrics = multi_trainer.evaluate()
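
The `compute_metrics` callback referenced above is not shown in this snippet. A minimal sketch of what it might look like is below; note that `answer_space` and `wup_measure` are hypothetical names introduced here for illustration, and the real WUPS metric scores answers by Wu-Palmer similarity over WordNet synsets (e.g. via nltk) rather than the exact-match placeholder used to keep this sketch self-contained.

answer_space = ["yes", "no", "red", "blue"]  # hypothetical label-id -> answer-string mapping

def wup_measure(pred, truth):
    # Placeholder: real WUPS uses WordNet-based Wu-Palmer similarity between the two answers
    return 1.0 if pred == truth else 0.0

def compute_metrics(eval_tuple):
    # The HF Trainer passes an EvalPrediction, which unpacks as (predictions, label_ids)
    logits, label_ids = eval_tuple
    # Argmax over the answer-space logits for each example
    pred_ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    scores = [
        wup_measure(answer_space[p], answer_space[t])
        for p, t in zip(pred_ids, label_ids)
    ]
    # Returned dict keys become eval metrics; "wups" matches metric_for_best_model above
    return {"wups": sum(scores) / len(scores)}

Because `metric_for_best_model="wups"` is set, the Trainer uses the `"wups"` key of this dict (prefixed as `eval_wups`) to decide which checkpoint to restore when `load_best_model_at_end=True`.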