Skip to content

Instantly share code, notes, and snippets.

@ceshine
Last active January 21, 2018 23:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ceshine/58ea606359a384e9e00825a018e42b2e to your computer and use it in GitHub Desktop.
Save ceshine/58ea606359a384e9e00825a018e42b2e to your computer and use it in GitHub Desktop.
Prepare toxic comment dataset in fastText format
import pandas as pd
import joblib
from sklearn.model_selection import train_test_split
# Column names of the toxicity labels. A label's position in this list is
# the numeric id emitted after the "__label__" prefix.
LABELS = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]

# Extra id, one past the last real label, used for rows where no label is set.
EMPTY_ID = len(LABELS)
def create_labeled_string(row):
    """Build one labeled text line for a single row.

    The result is the row's ``comment_text_cleaned`` value followed by one
    ``__label__<id>`` token for every entry of ``LABELS`` whose column is
    truthy in the row, all separated by single spaces. ``<id>`` is the
    label's index in ``LABELS``. When no label column is truthy, a single
    ``__label__<EMPTY_ID>`` token is appended instead.

    Args:
        row: Object indexable by column name (e.g. a DataFrame row) that
            provides ``comment_text_cleaned`` and every name in ``LABELS``.

    Returns:
        The space-joined line as a string.
    """
    text = row["comment_text_cleaned"]
    active_ids = [idx for idx, name in enumerate(LABELS) if row[name]]
    if not active_ids:
        # No label applies: mark the row with the dedicated "empty" id.
        active_ids = [EMPTY_ID]
    tags = ["__label__{}".format(idx) for idx in active_ids]
    return " ".join([text] + tags)
def _write_lines(path, lines):
    """Write *lines* to *path* as UTF-8, separated by newlines."""
    # Explicit encoding: the default is locale-dependent and can fail (or
    # produce different bytes) on comments containing non-ASCII characters.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main():
    """Convert the training data into labeled text files.

    Steps:
      1. Read columns 2-7 (by position) of ``data/train.csv``; these must
         be the columns named in ``LABELS``, since
         ``create_labeled_string`` looks them up by name.
      2. Load the pre-tokenized comments from ``cache/train_tokenized.pkl``
         and join the non-blank tokens of each comment with single spaces
         into a ``comment_text_cleaned`` column.
      3. Split the rows 75% / 25% into train and validation sets with a
         fixed ``random_state`` so the split is reproducible.
      4. Write one labeled line per row to ``cache/train.txt`` and
         ``cache/val.txt``.
    """
    train = pd.read_csv("data/train.csv", usecols=[2, 3, 4, 5, 6, 7])

    # NOTE(review): joblib.load unpickles its input; only load cache files
    # produced by a trusted pipeline.
    train_tokens = joblib.load("cache/train_tokenized.pkl")

    # Tokens are stringified, and those that are empty or whitespace-only
    # are dropped before joining.
    train["comment_text_cleaned"] = [
        " ".join([str(x) for x in tokens if str(x).strip() != ""])
        for tokens in train_tokens
    ]

    train, val = train_test_split(train, test_size=0.25, random_state=24)

    _write_lines("cache/train.txt", train.apply(create_labeled_string, axis=1))
    _write_lines("cache/val.txt", val.apply(create_labeled_string, axis=1))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment