Skip to content

Instantly share code, notes, and snippets.

@zhpmatrix
Created March 8, 2019 10:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhpmatrix/d2a944e1b5b7416131dc76cd884a78b2 to your computer and use it in GitHub Desktop.
Save zhpmatrix/d2a944e1b5b7416131dc76cd884a78b2 to your computer and use it in GitHub Desktop.
"""
http://www.nlpuser.com/pytorch/2018/10/30/useTorchText/
http://anie.me/On-Torchtext/
"""
import pandas as pd
from torchtext import data
def get_dataset(data_, text_field, label_field, test=False):
    """Build torchtext ``Example`` objects from a DataFrame.

    Args:
        data_: Table with a ``'comment'`` column and, unless ``test`` is true,
            a ``'label'`` column (read via ``data_['comment']`` /
            ``data_['label']``).
        text_field: Field that processes the ``comment`` column.
        label_field: Field that processes the ``label`` column. Ignored when
            ``test`` is true, because test rows carry no label.
        test: When true, only comments are read and the label slot is left
            unprocessed.

    Returns:
        ``(examples, fields)``, where ``fields`` is the ``(name, field)`` list
        to pass to ``data.Dataset`` together with ``examples``.
    """
    # Test rows have no label value (None is used as a placeholder). A real
    # Field in that slot would be handed None to process, so map the slot to
    # None instead -- the same mechanism the unused 'id' slot relies on.
    if test:
        label_field = None
    fields = [('id', None), ('comment', text_field), ('label', label_field)]

    if test:
        rows = ([None, text, None] for text in data_['comment'])
    else:
        rows = ([None, text, label]
                for text, label in zip(data_['comment'], data_['label']))

    examples = [data.Example.fromlist(row, fields) for row in rows]
    return examples, fields
def get_data_iter(data):
    """Yield ``data[0]``, ``data[1]``, ... ``data[len(data) - 1]`` in order.

    A toy generator that simulates a batch iterator with ``yield``. Items are
    fetched by integer index, not via ``iter(data)``.

    Note: the parameter name shadows the module-level ``torchtext.data``
    import inside this function.
    """
    position = 0
    total = len(data)
    while position < total:
        yield data[position]
        position += 1
if __name__ == '__main__':
    # The files are tab-separated despite the .csv suffix. Both need a
    # 'comment' column; train.csv also needs 'label' (see get_dataset).
    train = pd.read_csv('data/train.csv', sep='\t')
    test = pd.read_csv('data/test.csv', sep='\t')

    # TEXT: whitespace-tokenised, lower-cased, with a vocabulary.
    # LABEL: a single non-sequential value used as-is (use_vocab=False).
    TEXT = data.Field(sequential=True, tokenize=str.split, lower=True)
    LABEL = data.Field(sequential=False, use_vocab=False)

    train_examples, train_fields = get_dataset(train, TEXT, LABEL)
    test_examples, test_fields = get_dataset(test, TEXT, None, True)
    train_ = data.Dataset(train_examples, train_fields)
    test_ = data.Dataset(test_examples, test_fields)

    # The vocabulary is built from the training split only.
    TEXT.build_vocab(train_)

    # Integer device ids such as -1 are a deprecated torchtext form (mapped
    # to CPU with a warning); pass the device as a string instead.
    device = 'cpu'
    train_batches = data.BucketIterator(
        train_,
        batch_size=3,
        device=device,
        sort_key=lambda x: len(x.comment),  # group comments of similar length
        sort_within_batch=True,
        repeat=False,
    )
    # Built for completeness; this demo only iterates the training batches.
    test_batches = data.Iterator(
        test_, batch_size=4, device=device, sort=False, repeat=False)

    for batch in train_batches:
        comment, label = batch.comment, batch.label
        print(comment.shape, label.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment