# toxic_bert classification
import csv
import json
import os
import re
import datasets
import pdb
import string
import random
import gdown
# TODO: Add BibTeX citation
# Find for instance the citation on arxiv or on the dataset repo/website
_CITATION = """\
@InProceedings{huggingface:dataset,
title = {A great new dataset},
author={huggingface, Inc.
},
year={2020}
}
"""
# TODO: Add description of the dataset here
# You can copy an official description
_DESCRIPTION = """\
Comments Dataset
"""
# TODO: Add a link to an official homepage for the dataset here
_HOMEPAGE = ""
# TODO: Add the licence for the dataset here if you can find it
_LICENSE = ""
# TODO: Add link to the official dataset URLs here
# The HuggingFace Datasets library doesn't host the datasets but only points to the original files.
# This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
_URLS = {
    "full_domain": "example.com",
}
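# Note: gdown.download(..., fuzzy=True) below expects a Google Drive share link, so the
# "example.com" placeholder above must be replaced with something like (hypothetical ID):
#   "full_domain": "https://drive.google.com/file/d/<FILE_ID>/view?usp=sharing",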
# TODO: The name of the dataset usually matches the script name with CamelCase instead of snake_case
class Dataset(datasets.GeneratorBasedBuilder):
    """TODO: Short description of my dataset."""

    VERSION = datasets.Version("1.4.1")

    # This is an example of a dataset with multiple configurations.
    # If you don't want/need to define several sub-sets in your dataset,
    # just remove the BUILDER_CONFIG_CLASS and the BUILDER_CONFIGS attributes.
    # If you need to make complex sub-parts in the datasets with configurable options,
    # you can create your own builder configuration class to store attributes, inheriting from datasets.BuilderConfig:
    # BUILDER_CONFIG_CLASS = MyBuilderConfig
    # You will then be able to load one or the other configuration in the following list with
    #   data = datasets.load_dataset('my_dataset', 'first_domain')
    #   data = datasets.load_dataset('my_dataset', 'second_domain')
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="full_domain", version=VERSION, description="Complete Dataset"),
    ]

    # DEFAULT_CONFIG_NAME = "first_domain"  # It's not mandatory to have a default configuration. Just use one if it makes sense.
    def _info(self):
        # TODO: This method specifies the datasets.DatasetInfo object which contains information and typings for the dataset
        features = datasets.Features(
            {
                "comment_id": datasets.Value("string"),
                "comment": datasets.Value("string"),
                "processed": datasets.Value("string"),
                # These are the features of your dataset like images, labels ...
            }
        )
        # if self.config.name == "first_domain":  # This is the name of the configuration selected in BUILDER_CONFIGS above
        #     features = datasets.Features(
        #         {
        #             "commentId": datasets.Value("string"),
        #             "text": datasets.Value("string"),
        #         }
        #     )
        # else:  # This is an example of how to have different features for "first_domain" and "second_domain"
        #     features = datasets.Features(
        #         {
        #             "sentence": datasets.Value("string"),
        #             "option2": datasets.Value("string"),
        #             "second_domain_answer": datasets.Value("string"),
        #         }
        #     )
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # This defines the different columns of the dataset and their types
            features=features,  # Defined above because they can differ between configurations
            # If there's a common (input, target) tuple in the features, uncomment the supervised_keys line below and
            # specify them. They'll be used if as_supervised=True in builder.as_dataset.
            # supervised_keys=("sentence", "label"),
            # Homepage of the dataset for documentation
            homepage=_HOMEPAGE,
            # License for the dataset if available
            license=_LICENSE,
            # Citation for the dataset
            citation=_CITATION,
        )
    def _split_generators(self, dl_manager):
        # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
        # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
        # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLs
        # It can accept any type or nested list/dict and will give back the same structure with the URLs replaced with paths to local files.
        # By default the archives will be extracted and a path to a cached folder where they are extracted is returned instead of the archive
        urls = _URLS[self.config.name]
        # data_dir = dl_manager.download_and_extract(urls)
        data_dir = "Dataset/"
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        output = "_comments_" + self.config.name + ".tsv"
        gdown.download(url=urls, output=os.path.join(data_dir, output), quiet=False, fuzzy=True)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "test",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "validation",
                },
            ),
            # datasets.SplitGenerator(
            #     name=datasets.Split.VALIDATION,
            #     gen_kwargs={
            #         "filepath": os.path.join(data_dir, "dev.jsonl"),
            #         "split": "dev",
            #     },
            # ),
        ]
    # def data_loader(folder_path):
    #     data_file = []
    #     with open(folder_path) as file:
    #         tsv_file = csv.reader(file, delimiter="\t")
    #         for line in tsv_file:
    #             data_file.append(line)
    #     return data_file

    # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    def _generate_examples(self, filepath, split):
        # This method yields the examples; for the validation and train splits it also
        # deals with comments longer than roughly 510 characters.
        with open(filepath) as f:
            tsv_file = csv.reader(f, delimiter="\t")
            key = -1
            for data in tsv_file:
                # skip the header row
                if data[0] == "commentId":
                    continue
                key += 1
                # if detect(data[1]) != "en":
                #     continue
                if split == "test":
                    yield key, {
                        "comment_id": data[0],
                        "comment": data[1],
                        "processed": data[1],
                    }
                if split == "validation":
                    # drop comments longer than 510 characters
                    if len(data[1]) > 510:
                        continue
                    else:
                        yield key, {
                            "comment_id": data[0],
                            "comment": data[1],
                            "processed": data[1][:510],
                        }
                if split == "train":
                    data[1] = data[1].replace('\\n', '')
                    data[1] = data[1].replace('\\xa0', '')
                    data[1] = data[1].replace('\\r', '')
                    # data[1] = re.sub(r"[^a-zA-Z0-9]+", ' ', data[1])
                    trimmed = data[1]
                    if len(data[1]) > 510:
                        # split long comments into sentences and yield one example per sentence,
                        # tagging the ids with "_EXT" so the classifier can re-merge them later
                        data_rows = data[1].split('.')
                        for data_row in data_rows:
                            trimmed = data_row
                            key += 1
                            if len(data_row) > 510:
                                data_row = re.sub(r"[^a-zA-Z]+", ' ', data_row)
                                trimmed = data_row[:511]
                            yield key, {
                                "comment_id": data[0] + "_EXT" + str(key),
                                "comment": data[1],
                                "processed": trimmed,
                            }
                    else:
                        yield key, {
                            "comment_id": data[0],
                            "comment": data[1],
                            "processed": trimmed,
                        }
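
# Example usage of the loader above (a minimal sketch; "comments_dataset.py" is an assumed
# filename, so point load_dataset at wherever this script is saved, after replacing the
# "example.com" placeholder in _URLS with a real download link):
#
#   from datasets import load_dataset
#   data = load_dataset("comments_dataset.py", "full_domain", split="train")
#   print(data[0]["comment_id"], data[0]["processed"][:80])


# ---------------------------------------------------------------------------
# classifier_toxicbert.py: runs unitary/toxic-bert over a dataset built by the
# loader script above
# ---------------------------------------------------------------------------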
# Uncomment the following lines if you get an error loading Torch
# import os
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
from torch import device
from transformers import pipeline
import csv
import pdb
from tqdm import tqdm
import pandas as pd
import argparse
import os
import re
from helper_function import create_folder
from datasets import load_dataset
# Set HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 for offline mode
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers.pipelines.pt_utils import KeyDataset

# The model accepts at most 512 tokens (not characters or words); longer inputs fail with
#   *** RuntimeError: The size of tensor a (1142) must match the size of tensor b (512) at non-singleton dimension 1
# which is why the dataset script trims comments to roughly 510 characters.
# Basic commands to run:
#   python classifier_toxicbert.py --input_raw_file "CNNDataset" --output_file "data/pred/toxic/small_cnn_news.pkl" --gpu_id 0
#   python classifier_toxicbert.py --input_raw_file "FOXDataset" --output_file "data/pred/fox_news.pkl" --gpu_id 0
#   python classifier_toxicbert.py --input_raw_file "MSNBCDataset" --output_file "data/pred/msncbc_news.pkl" --gpu_id 0
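# An alternative to character-based trimming (a sketch only, not what this script does):
# recent transformers versions let you ask the text-classification pipeline to truncate at
# the tokenizer level when it is called, e.g.
#   detoxify_pipeline(texts, truncation=True, max_length=512)
# Exact keyword support depends on the installed transformers version.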
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_raw_file", help="Input messages raw annotator file w/o IDs")
    parser.add_argument("--input_raw_bin", help="The bin for the input, given as binX", default=None)
    parser.add_argument("--output_file", help="Path to write the output predictions")
    parser.add_argument("--gpu_id", help="GPU ID", default=False)
    args = parser.parse_args()

    use_gpu = args.gpu_id
    if use_gpu:
        print(" > Using GPU ID {0}".format(use_gpu))
        os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(use_gpu)
        gpu_tag = 0
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        gpu_tag = -1

    current_path = args.input_raw_file
    output_path = args.output_file
    input_bin = args.input_raw_bin
    create_folder(output_path)
    # current_path = "data/news/CNN_comments_1.tsv"
    # dataset = data_loader(current_path)
    predictions = toxic_bert_predict_pipeline(current_path, gpu_tag, input_bin)
    pd_preds = pd.DataFrame(predictions)
    pd_preds.drop_duplicates(subset=['comment'], inplace=True)
    pd_preds.to_pickle(output_path)
def toxic_bert_predict_pipeline(dataset_name, gpu_id, input_bin):
    predictions = []
    detoxify_pipeline = pipeline(
        'text-classification',
        model='unitary/toxic-bert',
        tokenizer='bert-base-uncased',
        function_to_apply='sigmoid',
        device=gpu_id,
        return_all_scores=True,
    )
    # dataset = load_dataset(dataset_name, split="train")
    if input_bin:
        dataset = load_dataset(dataset_name, input_bin, split="validation")
    else:
        dataset = load_dataset(dataset_name, split="validation")
    multiple_values = {}
    for toxic_predictions, data_row in tqdm(zip(detoxify_pipeline(KeyDataset(dataset, "processed")), dataset)):
        # from langdetect import detect
        # if detect(data_row["comment"]) != "en":
        #     pdb.set_trace()
        del data_row['processed']
        # attach one column per toxicity label returned by the model
        for predict in toxic_predictions:
            data_row[predict['label']] = predict['score']
        if "EXT" in data_row["comment_id"]:
            # "_EXT" rows are sentences of a comment that exceeded the model's 512-token limit;
            # collect them and keep only the highest-scoring sentence once the group is complete.
            # Comments shorter than the limit never enter this branch.
            multiple_values[toxic_predictions[0]['score']] = data_row
            continue
        if multiple_values:
            # a non-EXT row means the previous "_EXT" group is finished: flush its maximum
            max_score = multiple_values[max(multiple_values)]
            comment_id = max_score['comment_id'].split('_EXT')
            max_score['comment_id'] = comment_id[0]
            predictions.append(max_score)
            multiple_values = {}
        predictions.append(data_row)
    # flush a trailing "_EXT" group if the dataset ends with one
    if multiple_values:
        max_score = multiple_values[max(multiple_values)]
        max_score['comment_id'] = max_score['comment_id'].split('_EXT')[0]
        predictions.append(max_score)
    return predictions
def data_loader(folder_path):
    data_file = []
    with open(folder_path) as file:
        tsv_file = csv.reader(file, delimiter="\t")
        for line in tsv_file:
            data_file.append(line)
    return data_file


if __name__ == "__main__":
    main()
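
# Example of inspecting the saved predictions (a sketch; the path is just the --output_file
# value from one of the commands above, and label columns such as "toxic" or "insult" come
# from the unitary/toxic-bert labels attached in the loop above):
#
#   import pandas as pd
#   preds = pd.read_pickle("data/pred/fox_news.pkl")
#   print(preds.columns.tolist())
#   print(preds.sort_values("toxic", ascending=False).head())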