@naturale0 (forked from akurniawan/translation_char_example.py)
Last active: February 19, 2021 15:05

import itertools

import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils import data

from torchtext.experimental.datasets import TextClassificationDataset
from torchtext.experimental.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator
from torchtext.experimental.functional import sequential_transforms
from torchtext.data.utils import get_tokenizer


def build_char_vocab(data, index, bow="<w>", eow="</w>"):
    """
    Build a character-level vocabulary from the token field at `index`.
    """
    tok_list = [
        [bow],
        [eow],
    ]
    for line in data:
        tokens = list(itertools.chain.from_iterable(line[index]))
        tok_list.append(tokens)
    return build_vocab_from_iterator(tok_list)


def stoi(vocab):
    """
    Convert characters to vocabulary indices.
    """
    def func(tok_iter):
        # "≥" is the word-length padding character; map it straight to
        # index 1, the default "<pad>" slot of the vocabulary built above.
        return [[vocab[char] if char != "≥" else 1 for char in word]
                for word in tok_iter]
    return func


def tokenize_char(bow="<w>", eow="</w>", max_word_length=12):
    """
    Attach bow/eow tokens to each word and pad (or truncate) every word
    to a fixed length of max_word_length + 2.
    """
    def func(tok_iter):
        # dtype=object keeps multi-character tokens such as "<w>" intact;
        # a fixed-width string dtype would truncate them.
        result = np.empty((len(tok_iter), max_word_length + 2), dtype=object)
        # "≥" is used for padding
        result[:] = [
            [bow] + word + [eow] + ["≥"] * (max_word_length - len(word))
            if len(word) < max_word_length
            else [bow] + word[:max_word_length] + [eow]
            for word in tok_iter]
        return result
    return func
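
# A toy illustration of the two helpers above (shown only as comments, with a
# hypothetical one-word input): tokenize_char()([list("hi")]) returns a (1, 14)
# array whose row is ['<w>', 'h', 'i', '</w>'] followed by ten '≥' pads, and
# stoi(char_vocab) then replaces every character with its vocabulary index,
# mapping '≥' to the padding index 1.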


if __name__ == "__main__":
    tokenizer = get_tokenizer("spacy")
    train, test = IMDB(tokenizer=tokenizer)

    # Cache the training data as (label, list-of-word-strings) pairs for
    # vocabulary construction; the word vocabulary comes from the dataset.
    word_vocab = train.get_vocab()
    train_data = [(line[0], [word_vocab.itos[ix] for ix in line[1]])
                  for line in train]

    # Set up the character vocabulary
    char_vocab = build_char_vocab(train_data, index=1)

    # Build the dataset with character-level tokenization
    def char_tokenizer(words):
        return [list(word) for word in words]

    char_transform = sequential_transforms(
        char_tokenizer,
        tokenize_char(),
        stoi(char_vocab),
        lambda x: torch.tensor(x),
    )
    trainset = TextClassificationDataset(
        train_data,
        char_vocab,
        (lambda x: x, char_transform),
    )
    # Prepare the DataLoader
    def collate_fn(batch):
        labels, text = zip(*batch)
        # Reviews contain different numbers of words, so pad them to the
        # longest review in the batch (padding rows are all zeros).
        text = pad_sequence(text, batch_first=True)
        # Count the real words per review: rows with no zero entry are
        # actual words, all-zero rows are review-level padding.
        lens = [len(x[(x != 0).all(dim=1)]) for x in text]
        return (
            torch.stack(labels),
            pack_padded_sequence(text, lens, batch_first=True, enforce_sorted=False),
        )

    trainloader = data.DataLoader(trainset, batch_size=32, collate_fn=collate_fn)
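
    # Minimal usage sketch (an illustrative addition, assuming the pipeline
    # above built successfully): each batch yields a label tensor and a
    # PackedSequence whose .data rows are (max_word_length + 2)-long vectors
    # of character indices.
    labels, packed = next(iter(trainloader))
    print(labels.shape)       # torch.Size([32])
    print(packed.data.shape)  # (number of real words in the batch, 14)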