toxic_bert classification: a HuggingFace datasets loading script for a comments corpus, followed by a script that runs the unitary/toxic-bert classifier over it.
import csv
import json
import os
import re
import datasets
import pdb
import string
import random
import gdown

# TODO: Add BibTeX citation
# Find for instance the citation on arxiv or on the dataset repo/website
_CITATION = """\
@InProceedings{huggingface:dataset,
title = {A great new dataset},
author={huggingface, Inc.
},
year={2020}
}
"""

# TODO: Add description of the dataset here
# You can copy an official description
_DESCRIPTION = """\
Comments Dataset
"""

# TODO: Add a link to an official homepage for the dataset here
_HOMEPAGE = ""

# TODO: Add the licence for the dataset here if you can find it
_LICENSE = ""

# TODO: Add link to the official dataset URLs here
# The HuggingFace Datasets library doesn't host the datasets but only points to the original files.
# This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
_URLS = {
    "full_domain": "example.com",
}
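# Note: the real download URL is redacted here. Since _split_generators below calls
# gdown.download(..., fuzzy=True), the entry would typically be a Google Drive share link,
# e.g. (hypothetical): "https://drive.google.com/file/d/<FILE_ID>/view?usp=sharing".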
# TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case
class Dataset(datasets.GeneratorBasedBuilder):
    """TODO: Short description of my dataset."""

    VERSION = datasets.Version("1.4.1")

    # This is an example of a dataset with multiple configurations.
    # If you don't want/need to define several sub-sets in your dataset,
    # just remove the BUILDER_CONFIG_CLASS and the BUILDER_CONFIGS attributes.
    # If you need to make complex sub-parts in the datasets with configurable options
    # You can create your own builder configuration class to store attribute, inheriting from datasets.BuilderConfig
    # BUILDER_CONFIG_CLASS = MyBuilderConfig
    # You will be able to load one or the other configurations in the following list with
    # data = datasets.load_dataset('my_dataset', 'first_domain')
    # data = datasets.load_dataset('my_dataset', 'second_domain')
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="full_domain", version=VERSION, description="Complete Dataset"),
    ]

    # DEFAULT_CONFIG_NAME = "first_domain"  # It's not mandatory to have a default configuration. Just use one if it make sense.

    def _info(self):
        # TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset
        features = datasets.Features(
            {
                "comment_id": datasets.Value("string"),
                "comment": datasets.Value("string"),
                "processed": datasets.Value("string"),
                # These are the features of your dataset like images, labels ...
            }
        )
        # if self.config.name == "first_domain":  # This is the name of the configuration selected in BUILDER_CONFIGS above
        #     features = datasets.Features(
        #         {
        #             "commentId": datasets.Value("string"),
        #             "text": datasets.Value("string"),
        #             # These are the features of your dataset like images, labels ...
        #         }
        #     )
        # else:  # This is an example to show how to have different features for "first_domain" and "second_domain"
        #     features = datasets.Features(
        #         {
        #             "sentence": datasets.Value("string"),
        #             "option2": datasets.Value("string"),
        #             "second_domain_answer": datasets.Value("string")
        #             # These are the features of your dataset like images, labels ...
        #         }
        #     )
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # This defines the different columns of the dataset and their types
            features=features,  # Here we define them above because they are different between the two configurations
            # If there's a common (input, target) tuple from the features, uncomment supervised_keys line below and
            # specify them. They'll be used if as_supervised=True in builder.as_dataset.
            # supervised_keys=("sentence", "label"),
            # Homepage of the dataset for documentation
            homepage=_HOMEPAGE,
            # License for the dataset if available
            license=_LICENSE,
            # Citation for the dataset
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
        # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
        # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLS
        # It can accept any type or nested list/dict and will give back the same structure with the url replaced with path to local files.
        # By default the archives will be extracted and a path to a cached folder where they are extracted is returned instead of the archive
        urls = _URLS[self.config.name]  # "first_domain"]
        # data_dir = dl_manager.download_and_extract(urls)
        data_dir = "Dataset/"
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        output = "_comments_" + self.config.name + ".tsv"
        gdown.download(url=urls, output=os.path.join(data_dir, output), quiet=False, fuzzy=True)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "test",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, output),
                    "split": "validation",
                },
            ),
            # datasets.SplitGenerator(
            #     name=datasets.Split.VALIDATION,
            #     # These kwargs will be passed to _generate_examples
            #     gen_kwargs={
            #         "filepath": os.path.join(data_dir, "dev.jsonl"),
            #         "split": "dev",
            #     },
            # ),
        ]
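    # Note: all three splits above point at the same downloaded TSV; the "split" kwarg only
    # changes how _generate_examples below filters and chunks the rows.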
    # def data_loader(folder_path):
    #     data_file = []
    #     with open(folder_path) as file:
    #         tsv_file = csv.reader(file, delimiter="\t")
    #         for line in tsv_file:
    #             data_file.append(line)
    #     return data_file

    # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    def _generate_examples(self, filepath, split):
        # this function handles the examples and also chops down longer strings to 510
        with open(filepath) as f:
            tsv_file = csv.reader(f, delimiter="\t")
            key = -1
            for data in tsv_file:
                # data = json.loads(row)
                if data[0] == "commentId":
                    continue
                key += 1
                # if detect(data[1]) != "en":
                #     continue
                if split == "test":
                    yield key, {
                        "comment_id": data[0],
                        "comment": data[1],
                        "processed": data[1],
                        # "second_domain_answer": "" if split == "test" else data["second_domain_answer"],
                    }
                if split == "validation":
                    if len(data[1]) > 510:
                        continue
                    else:
                        yield key, {
                            "comment_id": data[0],
                            "comment": data[1],
                            "processed": data[1][:510],
                            # "second_domain_answer": "" if split == "test" else data["second_domain_answer"],
                        }
                if split == "train":
                    data[1] = data[1].replace('\\n', '')
                    data[1] = data[1].replace('\\xa0', '')
                    data[1] = data[1].replace('\\r', '')
                    # data[1] = re.sub(r"[^a-zA-Z0-9]+", ' ', data[1])
                    trimmed = data[1]
                    if len(data[1]) > 510:
                        data_rows = data[1].split('.')
                        for data_row in data_rows:
                            trimmed = data_row
                            key += 1
                            if len(data_row) > 510:
                                data_row = re.sub(r"[^a-zA-Z]+", ' ', data_row)
                                trimmed = data_row[:511]
                            yield key, {
                                "comment_id": data[0] + "_EXT" + str(key),
                                "comment": data[1],
                                "processed": trimmed,
                                # "second_domain_answer": "" if split == "test" else data["second_domain_answer"],
                            }
                    else:
                        yield key, {
                            "comment_id": data[0],
                            "comment": data[1],
                            "processed": trimmed,
                            # "second_domain_answer": "" if split == "test" else data["second_domain_answer"],
                        }
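The loading script above can be exercised on its own before running the classifier. A minimal sketch, assuming the script is saved locally (the file name here is hypothetical; the "full_domain" config and the train/test/validation splits come from the script itself, and depending on the installed datasets version loading a local script may require trust_remote_code=True):

from datasets import load_dataset

dataset = load_dataset("comments_dataset.py", "full_domain", split="train")
print(dataset[0])  # {'comment_id': ..., 'comment': ..., 'processed': ...}

The classification script below loads the same dataset by name and runs unitary/toxic-bert over the "processed" column.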
# Uncomment the following lines if you get an error loading Torch
# import os
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
from torch import device
from transformers import pipeline
import csv
import pdb
from tqdm import tqdm
import pandas as pd
import argparse
import os
import re
from helper_function import create_folder
from datasets import load_dataset
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 for offline mode
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers.pipelines.pt_utils import KeyDataset

# The model has a limit of 512 characters (not words); exceeding it raises: *** RuntimeError: The size of tensor a (1142) must match the size of tensor b (512) at non-singleton dimension 1
# Basic command to run: python classifier_toxicbert.py --input_raw_file "CNNDataset" --output_file "data/pred/toxic/small_cnn_news.pkl" --gpu_id 0
# Basic command to run: python classifier_toxicbert.py --input_raw_file "FOXDataset" --output_file "data/pred/fox_news.pkl" --gpu_id 0
# Basic command to run: python classifier_toxicbert.py --input_raw_file "MSNBCDataset" --output_file "data/pred/msncbc_news.pkl" --gpu_id 0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_raw_file", help="Input messages raw annotator file w/o IDS")
    parser.add_argument("--input_raw_bin", help="The bin for the input, goes as binX", default=None)
    parser.add_argument("--output_file", help="Path to write the output predictions")
    parser.add_argument("--gpu_id", help="GPU ID", default=False)
    args = parser.parse_args()
    use_gpu = args.gpu_id
    if use_gpu:
        print(" > Using GPU ID {0}".format(use_gpu))
        os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(use_gpu)
        gpu_tag = 0
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        gpu_tag = -1
    current_path = args.input_raw_file
    output_path = args.output_file
    input_bin = args.input_raw_bin
    create_folder(output_path)
    # current_path = "data/news/CNN_comments_1.tsv"
    # dataset = data_loader(current_path)
    predictions = toxic_bert_predict_pipeline(current_path, gpu_tag, input_bin)
    pd_preds = pd.DataFrame(predictions)
    pd_preds.drop_duplicates(subset=['comment'], inplace=True)
    pd_preds.to_pickle(output_path)
def toxic_bert_predict_pipeline(dataset_name, gpu_id, input_bin):
    predictions = []
    detoxify_pipeline = pipeline(
        'text-classification',
        model='unitary/toxic-bert',
        tokenizer='bert-base-uncased',
        function_to_apply='sigmoid',
        device=gpu_id,
        return_all_scores=True
    )
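    # With return_all_scores=True the pipeline returns, for each input, a list of
    # {'label', 'score'} dicts covering all of toxic-bert's Jigsaw labels
    # (toxic, severe_toxic, obscene, threat, insult, identity_hate).
    # Note: recent transformers releases deprecate return_all_scores in favour of top_k=None.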
    # dataset = load_dataset(dataset_name, split="train")
    if input_bin:
        dataset = load_dataset(dataset_name, input_bin, split="validation")
    else:
        dataset = load_dataset(dataset_name, split="validation")
    multiple_values = {}
    for toxic_predictions, data_row in tqdm(zip(detoxify_pipeline(KeyDataset(dataset, "processed")), dataset)):
        # from langdetect import detect
        # if detect(dataset["comment"]) != "en":
        #     pdb.set_trace()
        del data_row['processed']
        for predict in toxic_predictions:
            data_row[predict['label']] = predict['score']
        if "EXT" in data_row["comment_id"]:
            # This branch handles a limitation of toxic BERT where the character limit is 512:
            # chunks of one long comment arrive as "<id>_EXT<n>" rows, and only the chunk with
            # the maximum score is kept; comments shorter than 512 never enter this branch.
            multiple_values[toxic_predictions[0]['score']] = data_row
            continue
        else:
            if multiple_values:
                # Flush the previous long comment's chunks before handling the current row.
                max_score = multiple_values[max(multiple_values)]
                comment_id = max_score['comment_id'].split('_EXT')
                max_score['comment_id'] = comment_id[0]
                predictions.append(max_score)
                multiple_values = {}
                # Note: no `continue` here, otherwise the current (non-EXT) row would be dropped.
        predictions.append(data_row)
    if multiple_values:
        # Flush a trailing group of "_EXT" chunks if the dataset ends on a long comment.
        max_score = multiple_values[max(multiple_values)]
        max_score['comment_id'] = max_score['comment_id'].split('_EXT')[0]
        predictions.append(max_score)
    return predictions
def data_loader(folder_path):
    data_file = []
    with open(folder_path) as file:
        tsv_file = csv.reader(file, delimiter="\t")
        for line in tsv_file:
            data_file.append(line)
    return data_file


if __name__ == "__main__":
    main()
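The predictions are written out as a pandas pickle, so they can be inspected afterwards with pandas alone. A minimal sketch (the path is one of the examples from the commands above; the score columns are the label names produced by the loop over toxic_predictions):

import pandas as pd

preds = pd.read_pickle("data/pred/fox_news.pkl")
print(preds[["comment_id", "comment", "toxic"]].head())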