@Dref360
Created March 30, 2022 15:41
Train a Hugging Face sequence-classification model on a dataset from the Hub. Adapted from the Hugging Face course.
import argparse

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

LABEL_COL = "label"
TEXT_COL = "text"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_args", nargs="+",
                        help="List of arguments to load the Dataset to train on (available on HF Hub)")
    parser.add_argument("--pretrained_pipeline", default="distilbert-base-uncased", type=str,
                        help="Pretrained pipeline to download (tokenizer and model)")
    parser.add_argument("--text_column", default=TEXT_COL)
    parser.add_argument("--label_column", default=LABEL_COL)
    parser.add_argument("--ckpt_path", default="./ckpt")
    return parser.parse_args()


def main(args):
    # Load the dataset from the Hugging Face Hub.
    ds = load_dataset(*args.dataset_args)
    if args.label_column not in ds["train"].column_names or args.text_column not in ds["train"].column_names:
        raise ValueError(f"Expecting {args.label_column} and {args.text_column} in dataset,"
                         f" found {ds['train'].column_names}")

    # Normalize column names so the rest of the script can rely on TEXT_COL / LABEL_COL.
    if args.text_column != TEXT_COL:
        ds = ds.rename_column(args.text_column, TEXT_COL)
    if args.label_column != LABEL_COL:
        ds = ds.rename_column(args.label_column, LABEL_COL)

    # The label column must be a ClassLabel feature so the number of classes can be read.
    num_classes = ds["train"].features[LABEL_COL].num_classes

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_pipeline)

    def preprocess_function(examples):
        # Tokenize without padding; the data collator pads each batch dynamically.
        return tokenizer(examples[TEXT_COL], truncation=True)

    tokenized_ds = ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(args.pretrained_pipeline,
                                                               num_labels=num_classes)

    training_args = TrainingArguments(
        output_dir=args.ckpt_path,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()


if __name__ == "__main__":
    main(parse_args())
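
Usage sketch. The filename train_hf_pipeline.py and the imdb dataset are illustrative assumptions, not part of the gist; any Hub dataset with train and test splits works, provided --text_column and --label_column point at the right columns and the label column is a ClassLabel feature:

    python train_hf_pipeline.py imdb --pretrained_pipeline distilbert-base-uncased --ckpt_path ./ckpt

With imdb, the default column names (text, label) already match, so only the positional dataset argument is required; checkpoints are written under --ckpt_path by the Trainer.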