This script was used to train the following classifier: https://huggingface.co/jantrienes/roberta-large-question-classifier
Train a Hugging Face text classification model
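The published checkpoint can be used directly for inference. A minimal sketch using the transformers pipeline API; the example question is illustrative only:

from transformers import pipeline

# Download the trained checkpoint from the Hub and classify a question.
classifier = pipeline(
    "text-classification",
    model="jantrienes/roberta-large-question-classifier",
)
print(classifier("Why does the sky appear blue?"))
# -> [{'label': ..., 'score': ...}]; labels are the dataset's resolve_type values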
name: huggingface
channels:
  - defaults
dependencies:
  - python=3.11
  - pip
  - pip:
      - datasets==2.14.5
      - evaluate==0.4.1
      - torch==2.1.0
      - transformers==4.35.1
      - scikit-learn
      - wandb
      - numpy
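Assuming the file above is saved as environment.yml, the environment can be created with conda env create -f environment.yml; the submission script below activates it by its name, huggingface.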
from pathlib import Path

import datasets
import evaluate
import numpy as np
import wandb
from sklearn.metrics import classification_report
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


def main():
    dataset_name = "launch/open_question_type"
    text_column = "question"
    label_column = "resolve_type"
    model_name_or_path = "roberta-large"
    output_dir = "output/roberta-large-question-classifier"

    # Name the W&B run context after the output directory (passed as an
    # explicit keyword; the first positional argument of wandb.init is job_type).
    wandb.init(project=Path(output_dir).name)

    dataset = datasets.load_dataset(dataset_name)

    # Collect the label set across all splits and sort it so that the
    # label<->id mapping is deterministic across runs.
    labels_unique = sorted(
        set(ex[label_column] for split in dataset.keys() for ex in dataset[split])
    )
    label2id = {l: i for i, l in enumerate(labels_unique)}
    id2label = {i: l for l, i in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def preprocess_function(examples):
        # Map string labels to integer ids and tokenize the question text.
        labels = [label2id[l] for l in examples[label_column]]
        tokenized = tokenizer(examples[text_column], truncation=True)
        tokenized["label"] = labels
        return tokenized

    tokenized_dataset = dataset.map(preprocess_function, batched=True)

    def compute_metrics(eval_pred):
        # Macro-averaged F1 over the question types.
        metric = evaluate.load("f1")
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(
            predictions=predictions, references=labels, average="macro"
        )

    # Pad dynamically per batch instead of to a fixed maximum length.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        weight_decay=0,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=512,
        num_train_epochs=30,
        logging_strategy="steps",
        logging_steps=50,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        # Keep only the best checkpoint (by macro F1) and reload it at the end.
        load_best_model_at_end=True,
        save_total_limit=1,
        metric_for_best_model="f1",
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    print("Save best model...")
    trainer.save_model()
    trainer.save_state()
    trainer.push_to_hub()

    def predict(split_name):
        # Run the best model on a split, print a report, and dump one
        # gold/predicted label per line for offline analysis.
        out_path = Path(training_args.output_dir)
        preds = trainer.predict(tokenized_dataset[split_name])
        y_pred = np.argmax(preds.predictions, axis=-1)
        y_true = preds.label_ids
        y_true = [id2label[y] for y in y_true]
        y_pred = [id2label[y] for y in y_pred]
        print(f"Evaluate {split_name}\n" + classification_report(y_true, y_pred))
        with open(out_path / f"y_true_{split_name}.txt", "w") as fout:
            for i in y_true:
                fout.write(str(i) + "\n")
        with open(out_path / f"y_pred_{split_name}.txt", "w") as fout:
            for i in y_pred:
                fout.write(str(i) + "\n")

    print("Run predictions with best model...")
    predict(split_name="train")
    predict(split_name="validation")
    predict(split_name="test")


if __name__ == "__main__":
    main()
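Because predict() writes one gold and one predicted label per line, the classification reports can be recomputed offline from those dumps. A minimal sketch, assuming the output_dir used in the script and the test split:

from pathlib import Path

from sklearn.metrics import classification_report

# Read back the label files written by predict() and rebuild the report.
out_path = Path("output/roberta-large-question-classifier")
y_true = out_path.joinpath("y_true_test.txt").read_text().splitlines()
y_pred = out_path.joinpath("y_pred_test.txt").read_text().splitlines()
print(classification_report(y_true, y_pred))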
#!/bin/bash
#SBATCH --nodes 1
#SBATCH --gpus 1
#SBATCH --time 05:00:00
#SBATCH --partition=GPUampere,GPUhopper

# Make `conda activate` available in this non-interactive batch shell.
eval "$(conda shell.bash hook)"
conda activate huggingface

python train.py
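Assuming the script above is saved as train.sh, submit it with sbatch train.sh. The partition names (GPUampere, GPUhopper) are cluster-specific and may need adjusting.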