@priya-dwivedi · Created April 5, 2022 01:05
Training Arguments for GEC (grammatical error correction) with the Hugging Face Seq2SeqTrainer
# Defining training-related arguments
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

batch_size = 16
args = Seq2SeqTrainingArguments(
    output_dir="/content/drive/MyDrive/c4_200m/weights",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    save_total_limit=2,             # keep only the two most recent checkpoints
    predict_with_generate=True,     # run generate() at eval so metrics see decoded text
    fp16=True,                      # mixed precision; requires a CUDA GPU
    gradient_accumulation_steps=6,  # effective batch size = 16 * 6 = 96 per device
    eval_steps=500,
    save_steps=500,                 # must stay a multiple of eval_steps for load_best_model_at_end
    load_best_model_at_end=True,
    logging_dir="/logs",
    report_to="wandb",
)
# Defining the trainer using the 🤗 Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=GrammarDataset(train_dataset, tokenizer),
    eval_dataset=GrammarDataset(test_dataset, tokenizer),
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# Training the model
trainer.train()
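
# For reference: GrammarDataset is defined elsewhere in the notebook. A
# plausible minimal version, assuming each record exposes "input" and
# "output" text fields (the field names and max_length are assumptions):
import torch

class GrammarDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, tokenizer, max_len=64):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # Tokenize the ungrammatical source and the corrected target;
        # DataCollatorForSeq2Seq pads input_ids, attention_mask, and labels
        item = self.tokenizer(record["input"], truncation=True,
                              max_length=self.max_len)
        labels = self.tokenizer(record["output"], truncation=True,
                                max_length=self.max_len)
        item["labels"] = labels["input_ids"]
        return item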