Last active
November 10, 2017 04:30
-
-
Save rabintang/d726d9a2f19f6a14ba21f2f4f8e97843 to your computer and use it in GitHub Desktop.
torchtext dataset example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io

from torchtext import data
class DatasetDemo(data.Dataset):
    """Tab-separated text dataset built on ``torchtext.data.Dataset``.

    Every non-empty line of the input file is expected to hold at least two
    tab-separated columns: ``target<TAB>content``. The first two columns are
    handed to ``data.Example.fromlist`` in the order ``[target, content]``,
    so ``fields`` must list the target field first and the content field
    second.
    """

    def __init__(self, fields, path=None, examples=None, encoding="utf-8",
                 **kwargs):
        """Build the dataset from a file or from pre-built examples.

        Args:
            fields: list of ``(name, Field)`` pairs, ordered to match
                ``[target, content]``.
            path: file to read when ``examples`` is None.
            examples: optional pre-built examples; when given, ``path`` is
                not read at all.
            encoding: text encoding used to decode ``path``.
            **kwargs: forwarded unchanged to the base-class constructor.

        Raises:
            TypeError: if neither ``path`` nor ``examples`` is given.
            IndexError: if a non-empty line has fewer than two columns.
        """
        if examples is None:
            if path is None:
                raise TypeError("either `path` or `examples` must be provided")
            examples = []
            # `file()` only exists on Python 2; `io.open` behaves the same on
            # Python 2 and 3, and the `with` block guarantees the handle is
            # closed even if a line fails to parse.
            with io.open(path, encoding=encoding) as handle:
                for line in handle:
                    # Drop the line terminator so it does not end up in the
                    # last column, and ignore blank lines (e.g. a trailing
                    # empty line at the end of the file).
                    line = line.rstrip("\r\n")
                    if not line:
                        continue
                    parts = line.split("\t")
                    target = parts[0]
                    content = parts[1]
                    examples.append(
                        data.Example.fromlist([target, content], fields))
        super(DatasetDemo, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, fields, train_file, val_file, **kwargs):
        """Create the training and validation datasets.

        Args:
            fields: list of ``(name, Field)`` pairs shared by both datasets.
            train_file: path of the training file.
            val_file: path of the validation file.
            **kwargs: forwarded to the constructor for both datasets.

        Returns:
            A ``(trainset, valset)`` tuple.
        """
        trainset = cls(fields, train_file, **kwargs)
        valset = cls(fields, val_file, **kwargs)
        return (trainset, valset)
def load_data(**kwargs):
    """Build the training and validation iterators for the demo dataset.

    Relies on the names ``train_file``, ``val_file`` and ``batch_size``
    being defined outside this function; they are not set here.

    Args:
        **kwargs: iterator options (for example ``device`` or ``repeat``)
            forwarded to ``data.Iterator.splits``.

    Returns:
        A ``(train_loader, val_loader)`` tuple of iterators.
    """
    input_field = data.Field(lower=True)
    # NOTE: pay attention to the Field arguments; see the torchtext source
    # for what each one does.
    target_field = data.Field(sequential=False, unk_token=None)
    # Field objects can be shared between datasets. The names given here
    # ("target", "text") become the attributes bound on every example.
    # The order must match the [target, content] list that DatasetDemo
    # passes to Example.fromlist; otherwise labels and text are swapped.
    fields = [("target", target_field), ("text", input_field)]
    # kwargs carry iterator options only, so they are not forwarded to the
    # dataset constructor, which does not accept them.
    train_data, val_data = DatasetDemo.splits(fields, train_file, val_file)
    # NOTE: build the vocabularies before creating the iterators. (The
    # original note here was left unfinished; presumably iteration fails
    # without a vocabulary -- confirm.)
    input_field.build_vocab(train_data, val_data)
    target_field.build_vocab(train_data, val_data)
    # NOTE: only data.Iterator can be used here, not torch's DataLoader.
    train_loader, val_loader = data.Iterator.splits(
        (train_data, val_data),
        # The whole validation set is served as a single batch.
        batch_sizes=(batch_size, len(val_data)),
        **kwargs)
    return train_loader, val_loader
def train():
    """Iterate over the training batches.

    Only the data-access part of the loop is shown; the actual training
    step was elided in the original example. Relies on ``use_cuda`` being
    defined outside this function.
    """
    # NOTE: device=-1 means the data is loaded on the CPU.
    train_loader, val_loader = load_data(device=-1, repeat=False)
    for batch in train_loader:
        # Read the attributes that the fields bound onto each example.
        # The local is called `inputs` to avoid shadowing the builtin
        # `input`.
        inputs, target = batch.text, batch.target
        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()
        # Placeholder: the remainder of the training step is not part of
        # this example.
        ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment