from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
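# Assumed input format (inferred from the fields used below, not from the data
# file itself): one JSON object per line with "question", "context", and an
# integer "label" in {0, 1}.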
dataset = load_dataset('json', data_files=['data/train_qa_vi_mailong.jsonl'])
checkpoint_name = "xlm-roberta-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
def preprocess_function(examples):
    return tokenizer(examples["question"], examples["context"], padding="max_length", truncation=True, max_length=512)
tokenized_data = dataset.map(preprocess_function, batched=True)
tokenized_data = tokenized_data["train"].train_test_split(test_size=0.1, seed=1996)
print(tokenized_data)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
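# Since the tokenizer above already pads every example to max_length=512, this
# collator mostly just batches the features into tensors; its dynamic padding
# would only matter if the fixed-length padding were removed.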
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name, num_labels=2)
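# num_labels=2 puts a freshly initialized binary classification head on top of
# XLM-R; only the encoder weights come from the pretrained checkpoint.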
training_args = TrainingArguments(
    output_dir=f"{checkpoint_name}-finetuned-retrieval",
    learning_rate=1e-6,
    auto_find_batch_size=True,
    num_train_epochs=2,
    save_total_limit=5,
)
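# No evaluation strategy is configured here, so with the library defaults the
# Trainer only trains and saves checkpoints (save_total_limit keeps at most 5);
# evaluation on the test split is done manually below via trainer.predict.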
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
# print(tokenized_data["test"]["label"])
results = trainer.predict(tokenized_data["test"])
# results = trainer.evaluate()
# print(results.label_ids)
labels = tokenized_data["test"]["label"]
preds = results.predictions
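# results.predictions holds the raw logits (shape: num_examples x 2);
# np.argmax below converts each row to a predicted class id.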
count_true = 0
count_false = 0
count_correct = 0
predict_true = 0
predict_false = 0
for label, pred in zip(labels, preds):
    pred = np.argmax(pred)
    print(label, pred)
    if label == pred:
        count_correct += 1
    if label == 0:
        count_false += 1
    if label == 1:
        count_true += 1
    if pred == 0:
        predict_false += 1
    if pred == 1:
        predict_true += 1
print("False label: {}\nTrue label: {}\nCorrect rate: {}".format(count_false, count_true, count_correct/len(labels)))
print(predict_true, predict_false)
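# Optional: the same evaluation with scikit-learn metrics (precision/recall/F1
# in addition to accuracy). A sketch only; it assumes scikit-learn is installed
# and is not part of the original gist.
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
pred_ids = np.argmax(preds, axis=-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, pred_ids, average="binary")
print("Accuracy: {:.4f}".format(accuracy_score(labels, pred_ids)))
print("Precision: {:.4f}  Recall: {:.4f}  F1: {:.4f}".format(precision, recall, f1))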