Created
February 6, 2021 18:19
-
-
Save EricFillion/5b73d805ce96d8cc66e048ef7348db2d to your computer and use it in GitHub Desktop.
Fine-tuning a Hate Speech Detection Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from happytransformer import HappyTextClassification | |
from datasets import load_dataset | |
import csv | |
# Colab: https://colab.research.google.com/drive/1z8m5QYi2tcQLd3m_fESK37SYbMBSqTO7?usp=sharing | |
def run_finetune(train_csv_path="train.csv", eval_csv_path="eval.csv",
                 train_split="train[0:1999]", eval_split="train[2000:2499]"):
    """Fine-tune a BERT binary classifier on the tweets_hate_speech_detection dataset.

    Downloads the requested train/eval slices, materializes them as CSV files
    (via ``generate_csv``), then reports the evaluation loss before and after
    one training run so the improvement is visible.

    Args:
        train_csv_path: Where to write the training CSV. Defaults to "train.csv".
        eval_csv_path: Where to write the evaluation CSV. Defaults to "eval.csv".
        train_split: datasets-style split slice for training examples.
        eval_split: datasets-style split slice for evaluation examples.
    """
    train_dataset = load_dataset('tweets_hate_speech_detection', split=train_split)
    eval_dataset = load_dataset('tweets_hate_speech_detection', split=eval_split)

    generate_csv(train_csv_path, train_dataset)
    generate_csv(eval_csv_path, eval_dataset)

    # num_labels=2: the dataset labels are binary (hate speech vs. not).
    happy_tc = HappyTextClassification(model_type="BERT", model_name="bert-base-uncased", num_labels=2)

    # Evaluate before training to get a baseline loss for comparison.
    before_loss = happy_tc.eval(eval_csv_path)
    happy_tc.train(train_csv_path)
    after_loss = happy_tc.eval(eval_csv_path)

    print("Before loss: ", before_loss.loss)
    print("After loss: ", after_loss.loss)
def generate_csv(csv_path, dataset):
    """Write a dataset of tweets to a CSV file with "text" and "label" columns.

    Args:
        csv_path: Destination path for the CSV file (overwritten if present).
        dataset: Iterable of mappings with "tweet" (str) and "label" keys.
    """
    # encoding="utf-8": tweets routinely contain non-ASCII characters, and the
    # platform-default codec (e.g. cp1252 on Windows) would raise UnicodeEncodeError.
    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["text", "label"])
        for case in dataset:
            writer.writerow([case["tweet"], case["label"]])
# Guard the entry point so importing this module does not trigger a full
# dataset download and training run.
if __name__ == "__main__":
    run_finetune()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment