Created
January 24, 2018 04:05
-
-
Save ceshine/50a71e266722d0b7b00e2641fc86eb6f to your computer and use it in GitHub Desktop.
A torchtext example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import re

import numpy as np
import pandas as pd
import spacy
import torch
from torchtext import data
# Shared spaCy English pipeline; only its tokenizer is used in this module.
NLP = spacy.load('en')
# Hard cap on comment length (in characters) applied before tokenization.
MAX_CHARS = 20000
# Fraction of the training rows held out for validation.
VAL_RATIO = 0.2
# Module-level logger for dataset-preparation progress messages.
LOGGER = logging.getLogger("toxic_dataset")
def tokenizer(comment):
    """Tokenize a raw comment string with the spaCy tokenizer.

    Replaces a set of noisy punctuation/quote characters with spaces,
    collapses repeated spaces, commas and question marks, truncates overly
    long comments, and returns the surviving token texts.

    Args:
        comment: raw comment (any type; coerced to ``str``).

    Returns:
        list[str]: token strings, with whitespace-only tokens removed.
    """
    comment = re.sub(
        r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(comment))
    comment = re.sub(r"[ ]+", " ", comment)
    # NOTE: the original also ran re.sub(r"\!+", "!", ...) here, but every
    # '!' is already replaced by a space in the character class above, so
    # that pass was dead code and has been dropped.
    comment = re.sub(r"\,+", ",", comment)
    comment = re.sub(r"\?+", "?", comment)
    # Bound the work spaCy has to do on pathological inputs.
    if len(comment) > MAX_CHARS:
        comment = comment[:MAX_CHARS]
    # .strip() drops ALL whitespace-only tokens, not just the single-space
    # token the original `!= " "` check caught.
    return [tok.text for tok in NLP.tokenizer(comment) if tok.text.strip()]
def prepare_csv(seed=999, val_ratio=None):
    """Split ``data/train.csv`` into train/val CSVs and clean ``data/test.csv``.

    Reads the raw Kaggle CSVs, replaces literal newlines in ``comment_text``
    with spaces, shuffles the training rows reproducibly and writes:

      * ``cache/dataset_train.csv`` -- (1 - val_ratio) of the training rows
      * ``cache/dataset_val.csv``   -- val_ratio of the training rows
      * ``cache/dataset_test.csv``  -- cleaned test rows

    Args:
        seed: RNG seed for the train/validation shuffle (default 999).
        val_ratio: validation fraction; defaults to module-level VAL_RATIO.
    """
    if val_ratio is None:
        val_ratio = VAL_RATIO
    # Make sure the output directory exists instead of failing on to_csv.
    os.makedirs("cache", exist_ok=True)
    df_train = pd.read_csv("data/train.csv")
    # regex=False: plain-string replacement (explicit, since pandas'
    # historical default for single-char patterns was ambiguous).
    df_train["comment_text"] = df_train.comment_text.str.replace(
        "\n", " ", regex=False)
    idx = np.arange(df_train.shape[0])
    # Local RandomState: same shuffle sequence as the legacy global
    # np.random.seed(seed) + np.random.shuffle(idx), but without mutating
    # the global RNG state other code may depend on.
    np.random.RandomState(seed).shuffle(idx)
    val_size = int(len(idx) * val_ratio)
    df_train.iloc[idx[val_size:], :].to_csv(
        "cache/dataset_train.csv", index=False)
    df_train.iloc[idx[:val_size], :].to_csv(
        "cache/dataset_val.csv", index=False)
    df_test = pd.read_csv("data/test.csv")
    df_test["comment_text"] = df_test.comment_text.str.replace(
        "\n", " ", regex=False)
    df_test.to_csv("cache/dataset_test.csv", index=False)
def get_dataset(fix_length=100, lower=False, vectors=None):
    """Build torchtext train/val/test datasets for the toxic-comment CSVs.

    Regenerates the cached CSV splits, defines the shared comment field and
    one binary field per label column, loads the three TabularDatasets and
    builds the comment vocabulary (optionally with pretrained vectors).

    Args:
        fix_length: pad/trim every comment to this many tokens.
        lower: lowercase the text; forced to True when ``vectors`` is given.
        vectors: optional pretrained word vectors for the vocabulary.

    Returns:
        (train, val, test) TabularDatasets sharing one comment Field.
    """
    if vectors is not None:
        # pretrain vectors only supports all lower cases
        lower = True

    LOGGER.debug("Preparing CSV files...")
    prepare_csv()

    comment = data.Field(
        sequential=True,
        fix_length=fix_length,
        tokenize=tokenizer,
        pad_first=True,
        tensor_type=torch.cuda.LongTensor,
        lower=lower,
    )

    def _label_field():
        # A fresh non-sequential binary field for one toxicity column.
        return data.Field(
            use_vocab=False, sequential=False,
            tensor_type=torch.cuda.ByteTensor)

    label_columns = [
        "toxic", "severe_toxic", "obscene",
        "threat", "insult", "identity_hate",
    ]

    LOGGER.debug("Reading train csv file...")
    train, val = data.TabularDataset.splits(
        path='cache/', format='csv', skip_header=True,
        train='dataset_train.csv', validation='dataset_val.csv',
        fields=[('id', None), ('comment_text', comment)]
               + [(col, _label_field()) for col in label_columns],
    )

    LOGGER.debug("Reading test csv file...")
    test = data.TabularDataset(
        path='cache/dataset_test.csv', format='csv', skip_header=True,
        fields=[('id', None), ('comment_text', comment)],
    )

    LOGGER.debug("Building vocabulary...")
    comment.build_vocab(
        train, val, test,
        max_size=20000,
        min_freq=50,
        vectors=vectors,
    )
    LOGGER.debug("Done preparing the datasets")
    return train, val, test
def get_iterator(dataset, batch_size, train=True, shuffle=True, repeat=False):
    """Wrap *dataset* in an unsorted torchtext Iterator on CUDA device 0.

    Args:
        dataset: a torchtext Dataset to batch.
        batch_size: number of examples per batch.
        train: whether this iterator serves training (affects shuffling defaults).
        shuffle: shuffle examples between epochs.
        repeat: loop over the data indefinitely instead of one epoch.

    Returns:
        A ``data.Iterator`` over *dataset*.
    """
    return data.Iterator(
        dataset,
        batch_size=batch_size,
        device=0,
        train=train,
        shuffle=shuffle,
        repeat=repeat,
        sort=False,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment