@hadifar
Created August 19, 2023 11:26
Simple classifier: fine-tunes roberta-base for 7-way sequence classification on a JSON dataset using the Hugging Face Trainer.
import evaluate
import numpy as np
import random
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
CACHE_DIR = 'cache/'
ROOT_DIR = 'dataset/'
MODEL_ID = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
bea_dataset = load_dataset(
    'json',
    data_files={'train': [ROOT_DIR + "train_ds.json"], 'test': [ROOT_DIR + "test_ds.json"]},
)
def set_seed(seed):
    # Reproducibility: seed Python, NumPy, and PyTorch, and make cuDNN deterministic.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)
id2label = {
    0: "Acquisition",
    1: "Company_Invest",
    2: "Contract",
    3: "Government_Invest",
    4: "Market_Outlook",
    5: "New_Product",
    6: "Other",
}
# Sort entries by label name (the ids above already follow alphabetical order, so the
# id -> label mapping is unchanged) and build the inverse label -> id mapping.
id2label = dict(sorted(id2label.items(), key=lambda item: item[1]))
label2id = {label: iid for iid, label in id2label.items()}
def preprocess_label2id_fn(example):
    # Map the string label to its integer id.
    example["label"] = label2id[example["label"]]
    return example

def preprocess_tokenize_fn(examples):
    # Tokenize without padding; the data collator pads each batch dynamically.
    return tokenizer(examples["text"], truncation=True)
bea_dataset = bea_dataset.map(preprocess_label2id_fn, batched=False)
bea_dataset = bea_dataset.map(preprocess_tokenize_fn, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    # Accuracy over the argmax of the logits.
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, num_labels=len(id2label), id2label=id2label, label2id=label2id, cache_dir=CACHE_DIR,
)
training_args = TrainingArguments(
    output_dir="cache/logs/classifier",
    overwrite_output_dir=True,
    learning_rate=5e-5,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    # weight_decay=0.01,
    # evaluation_strategy="epoch",
    # save_strategy="epoch",
    # load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=bea_dataset["train"],
    eval_dataset=bea_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
trainer.save_model('model/classifier/')
tokenizer.save_pretrained('model/classifier/')
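
# Optional usage sketch: load the saved classifier with the transformers pipeline API
# and run a quick prediction. This assumes training completed and the model was saved
# to 'model/classifier/' as above; the example sentence is purely illustrative.
from transformers import pipeline

classifier = pipeline("text-classification", model="model/classifier/", tokenizer="model/classifier/")
print(classifier("The company announced the acquisition of a smaller competitor."))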