Skip to content

Instantly share code, notes, and snippets.

@EricFillion
Created February 6, 2021 18:19
Show Gist options
Save EricFillion/5b73d805ce96d8cc66e048ef7348db2d to your computer and use it in GitHub Desktop.
Fine-tuning a Hate Speech Detection Model
from happytransformer import HappyTextClassification
from datasets import load_dataset
import csv
# Colab: https://colab.research.google.com/drive/1z8m5QYi2tcQLd3m_fESK37SYbMBSqTO7?usp=sharing
def run_finetune(train_csv_path="train.csv", eval_csv_path="eval.csv"):
    """Fine-tune a BERT classifier on tweet hate-speech data and report eval loss.

    Loads two disjoint slices of the ``tweets_hate_speech_detection`` train
    split, writes each to a CSV file with a ``text``/``label`` header via
    ``generate_csv``, and builds a 2-label ``bert-base-uncased`` classifier.
    It then evaluates the classifier on the eval CSV, trains it on the train
    CSV, evaluates again, and prints the loss before and after training.

    Args:
        train_csv_path: Path the training CSV is written to (and trained from).
        eval_csv_path: Path the evaluation CSV is written to (and evaluated on).
    """
    # Split slices are half-open, like Python slices: this selects rows
    # 0..1999 (2000 rows) for training and 2000..2499 (500 rows) for eval,
    # with no overlap and no skipped row. The original `train[0:1999]` /
    # `train[2000:2499]` silently dropped row 1999 and the last eval row.
    train_dataset = load_dataset('tweets_hate_speech_detection', split='train[:2000]')
    eval_dataset = load_dataset('tweets_hate_speech_detection', split='train[2000:2500]')

    generate_csv(train_csv_path, train_dataset)
    generate_csv(eval_csv_path, eval_dataset)

    happy_tc = HappyTextClassification(model_type="BERT", model_name="bert-base-uncased", num_labels=2)

    # Baseline eval before training, so the effect of fine-tuning is visible.
    before_result = happy_tc.eval(eval_csv_path)
    happy_tc.train(train_csv_path)
    after_result = happy_tc.eval(eval_csv_path)

    print("Before loss: ", before_result.loss)
    print("After loss: ", after_result.loss)
def generate_csv(csv_path, dataset, text_key="tweet", label_key="label"):
    """Write ``dataset`` to ``csv_path`` as a two-column ``text,label`` CSV.

    Args:
        csv_path: Destination file path; any existing file is overwritten.
        dataset: Iterable of mapping-like records (e.g. rows of a Hugging Face
            dataset), each providing ``text_key`` and ``label_key`` entries.
        text_key: Record key whose value goes in the ``text`` column.
        label_key: Record key whose value goes in the ``label`` column.

    Raises:
        KeyError: If a record lacks ``text_key`` or ``label_key``.
    """
    # newline='' is required by the csv module so it controls line endings;
    # an explicit UTF-8 encoding keeps non-ASCII tweet text (emoji, accents)
    # from depending on the platform's locale default encoding.
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["text", "label"])
        writer.writerows((case[text_key], case[label_key]) for case in dataset)
# Only download data and train when executed as a script (or a notebook
# cell), not as a side effect of importing this module.
if __name__ == "__main__":
    run_finetune()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment