Skip to content

Instantly share code, notes, and snippets.

@rabintang
Last active November 10, 2017 04:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rabintang/d726d9a2f19f6a14ba21f2f4f8e97843 to your computer and use it in GitHub Desktop.
torchtext dataset example
from torchtext import data
class DatasetDemo(data.Dataset):
    """Tab-separated text-classification dataset for (legacy) torchtext.

    Each non-blank line of the input file has the form ``<target>\\t<content>``.
    The two columns are handed to ``data.Example.fromlist`` in file order, so
    *fields* must list the target field first and the content field second.
    """

    def __init__(self, fields, path=None, examples=None, encoding="utf-8",
                 **kwargs):
        """Build the dataset from the file at *path*, or from *examples*.

        Args:
            fields: list of ``(name, Field)`` pairs, ordered to match the file
                columns (target column first, content column second).
            path: path of the tab-separated file; read only when *examples*
                is None.
            examples: optional pre-built list of ``data.Example``; when given,
                *path* is ignored.
            encoding: text encoding used to read *path* (defaults to UTF-8;
                adjust if the data files use another encoding).
            **kwargs: forwarded unchanged to ``data.Dataset.__init__``.

        Raises:
            ValueError: if neither *path* nor *examples* is given, or if a
                non-blank line has no tab separator.
        """
        if examples is None:
            if path is None:
                raise ValueError("either path or examples must be provided")
            examples = self._read_examples(path, fields, encoding)
        super(DatasetDemo, self).__init__(examples, fields, **kwargs)

    @staticmethod
    def _read_examples(path, fields, encoding):
        """Parse *path* into a list of ``data.Example``, one per non-blank line."""
        import io

        examples = []
        # io.open works on both Python 2 and 3 (the original used the
        # Python-2-only ``file`` builtin); ``with`` guarantees the handle
        # is closed even if parsing fails part-way through.
        with io.open(path, encoding=encoding) as handle:
            for lineno, line in enumerate(handle, 1):
                # Drop the line terminator so it does not leak into the
                # content column.
                line = line.rstrip("\r\n")
                if not line.strip():
                    continue  # tolerate blank lines (e.g. at end of file)
                # maxsplit=1: a tab inside the content must not truncate it.
                parts = line.split("\t", 1)
                if len(parts) < 2:
                    raise ValueError(
                        "%s:%d: expected '<target>\\t<content>'" % (path, lineno))
                target, content = parts
                examples.append(data.Example.fromlist([target, content], fields))
        return examples

    @classmethod
    def splits(cls, fields, train_file, val_file, **kwargs):
        """Build ``(train, val)`` datasets from two files that share *fields*.

        *kwargs* are forwarded to both constructors, and from there to
        ``data.Dataset.__init__``.
        """
        trainset = cls(fields, train_file, **kwargs)
        valset = cls(fields, val_file, **kwargs)
        return (trainset, valset)
def load_data(**kwargs):
    """Build torchtext iterators over the training and validation files.

    Reads the module-level names ``train_file``, ``val_file`` and
    ``batch_size``. They are not defined in this excerpt; presumably they are
    set elsewhere in the module (confirm).

    Args:
        **kwargs: forwarded to ``data.Iterator.splits`` only (e.g.
            ``device=-1, repeat=False``). They are deliberately NOT passed to
            ``DatasetDemo.splits``: that would forward them to
            ``data.Dataset.__init__``, which does not accept iterator options.

    Returns:
        ``(train_loader, val_loader)``. The validation iterator uses a batch
        size of ``len(val_data)``, i.e. the whole validation set in one batch.
    """
    input_field = data.Field(lower=True)
    # NOTE: the Field arguments matter. The label field is non-sequential
    # (one value per example), and unk_token=None keeps an <unk> entry out
    # of the label vocabulary.
    target_field = data.Field(sequential=False, unk_token=None)
    # Fields can be shared between datasets; each name becomes an attribute
    # on every Example and Batch (``batch.text``, ``batch.target``).
    # The order must match the file columns that DatasetDemo passes to
    # Example.fromlist: column 0 is the target label, column 1 the text.
    fields = [("target", target_field), ("text", input_field)]
    train_data, val_data = DatasetDemo.splits(fields, train_file, val_file)
    # NOTE: build_vocab must run before iterating, because batches are
    # numericalized through each field's vocab.
    input_field.build_vocab(train_data, val_data)
    target_field.build_vocab(train_data, val_data)
    # NOTE: use data.Iterator here, not torch's DataLoader.
    train_loader, val_loader = data.Iterator.splits(
        (train_data, val_data),
        batch_sizes=(batch_size, len(val_data)),
        **kwargs)
    return train_loader, val_loader
def train():
    """Iterate over training batches. The training step itself is elided.

    Relies on a module-level ``use_cuda`` flag that is not defined in this
    excerpt; presumably it is set elsewhere in the module (confirm).
    """
    # NOTE: device=-1 means batches are built on the CPU; repeat=False
    # stops the iterator after a single pass over the data.
    train_loader, val_loader = load_data(device=-1, repeat=False)
    for batch in train_loader:
        # Access the attributes bound onto each example by the field names
        # declared in load_data. Named ``inputs`` rather than ``input`` to
        # avoid shadowing the builtin.
        inputs, target = batch.text, batch.target
        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()
        ...  # TODO: forward pass / loss / backward / optimizer step (elided in the original)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.