Forked from akurniawan/translation_char_example.py
import itertools

import numpy as np
import torch
from torch.utils import data
from torch.nn.utils.rnn import pack_padded_sequence

from torchtext.experimental.datasets import IMDB, TextClassificationDataset
from torchtext.experimental.functional import sequential_transforms
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
def build_char_vocab(data, index, bow="<w>", eow="</w>"):
    """
    Build a character-level vocabulary.

    `index` selects the field of each example that holds the word tokens;
    every word is flattened into its characters before counting.
    """
    tok_list = [
        [bow],
        [eow],
    ]
    for line in data:
        # line[index] is a list of word strings; chaining them yields a flat
        # stream of characters for this example
        tokens = list(itertools.chain.from_iterable(line[index]))
        tok_list.append(tokens)
    return build_vocab_from_iterator(tok_list)
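# A hedged usage sketch (toy data, not from the original gist):
#   toy = [(0, ["hi", "you"]), (1, ["bye"])]
#   char_vocab = build_char_vocab(toy, index=1)
# char_vocab then contains "<w>", "</w>" and the characters
# {'h', 'i', 'y', 'o', 'u', 'b', 'e'}, plus whatever specials
# build_vocab_from_iterator adds by default (typically <unk> and <pad>).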
def stoi(vocab):
    """
    Convert characters to vocabulary indices.
    """
    def func(tok_iter):
        # map every character of every word to its index; the padding
        # character "≥" is sent straight to index 1 (conventionally <pad>)
        return [[vocab[char] if char != "≥" else 1 for char in word]
                for word in tok_iter]
    return func
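# Hedged sketch of stoi's behavior (actual indices depend on the fitted vocab):
#   to_ids = stoi(char_vocab)
#   to_ids([["<w>", "h", "i", "</w>", "≥"]])
#   -> [[<id of "<w>">, <id of "h">, <id of "i">, <id of "</w>">, 1]]
# The padding character "≥" is hard-coded to index 1, the <pad> slot.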
def tokenize_char(bow="<w>", eow="</w>", max_word_length=12):
    """
    Attach bow/eow tokens to every word and pad to a fixed length.
    """
    def func(tok_iter):
        # dtype=object keeps multi-character tokens such as "<w>" intact
        result = np.empty((len(tok_iter), max_word_length + 2), dtype=object)
        # "≥" is used as the padding character
        result[:] = [
            [bow] + word + [eow] + ["≥"] * (max_word_length - len(word))
            if len(word) < max_word_length
            else [bow] + word[:max_word_length] + [eow]
            for word in tok_iter]
        return result
    return func
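# Hedged sketch: with the default max_word_length=12 every word becomes a row
# of exactly 14 tokens, e.g.
#   tokenize_char()([["h", "i"]])
#   -> [["<w>", "h", "i", "</w>", "≥", "≥", "≥", "≥", "≥", "≥", "≥", "≥", "≥", "≥"]]
# (a (1, 14) object array); words longer than 12 characters are truncated
# before "</w>" is appended.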
if __name__ == "__main__":
    tokenizer = get_tokenizer("spacy")
    train, test = IMDB(tokenizer=tokenizer)

    # Cache training data for vocabulary construction: map the word indices
    # produced by IMDB back to word strings using the dataset's own vocab
    word_vocab = train.get_vocab()
    train_data = [(line[0], [word_vocab.itos[ix] for ix in line[1]])
                  for line in train]

    # Setup vocabulary (characters)
    char_vocab = build_char_vocab(train_data, index=1)
    # Building the dataset with character-level tokenization
    def char_tokenizer(words):
        return [list(word) for word in words]

    char_transform = sequential_transforms(
        char_tokenizer,
        tokenize_char(),
        stoi(char_vocab),
        lambda x: torch.tensor(x),
    )
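    # Hedged sketch of the pipeline on a toy input (the exact ids depend on
    # the learned char_vocab, so the numbers below are illustrative only):
    #   char_transform(["hi"])
    #   -> char_tokenizer : [["h", "i"]]
    #   -> tokenize_char  : [["<w>", "h", "i", "</w>", "≥", ..., "≥"]]  (14 tokens)
    #   -> stoi           : [[2, 8, 9, 3, 1, ..., 1]]
    #   -> torch.tensor of shape (1, 14)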
    trainset = TextClassificationDataset(
        train_data,
        char_vocab,
        (lambda x: x, char_transform),
    )
    # Prepare DataLoader
    def collate_fn(batch):
        labels, text = zip(*batch)
        text = torch.stack(text)
        # rows (words) whose character ids are all non-zero count as real
        # words; zero-filled rows are treated as word-level padding and are
        # excluded from each example's length
        lens = [len(x[(x != 0).all(dim=1)]) for x in text]
        return (
            torch.stack(labels),
            pack_padded_sequence(text, lens, batch_first=True, enforce_sorted=False),
        )

    trainloader = data.DataLoader(trainset, batch_size=32, collate_fn=collate_fn)
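    # Hedged usage sketch (not in the original gist): pull one batch to
    # inspect shapes.
    #   labels, packed_text = next(iter(trainloader))
    #   labels.shape -> torch.Size([32])
    #   packed_text is a PackedSequence; torch.nn.utils.rnn.pad_packed_sequence(
    #   packed_text, batch_first=True)[0] recovers a (32, num_words, 14) tensor.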