@greed2411
Created June 18, 2018 13:27
torch text
name,age
bob,47
richard,48
# credits to: bentrevett
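import torch  # not used directly below
import torchtext
from torchtext import data

# Note: Field, LabelField and TabularDataset belong to the legacy torchtext API
# (moved to torchtext.legacy in 0.9 and removed in 0.12).

# Convert the label string read from the CSV into an int.
str2int = lambda x: int(x)

NAME = data.Field()                           # tokenized text column
AGE = data.LabelField(preprocessing=str2int)  # label column, cast to int

# Field order must match the column order in data.csv: name, age.
fields = [('n', NAME), ('a', AGE)]

# The same file is used for train and test, purely for demonstration.
train, test = data.TabularDataset.splits(
    path='',
    train='data.csv',
    test='data.csv',
    format='csv',
    fields=fields,
    skip_header=True,
)

print(vars(train[0]))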
# output
# (torch) λ python torchtext_str2int_labelfield.py
# {'n': ['bob'], 'a': 47}
# (torch) λ
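
For readers without a torchtext version that still ships `Field` and `LabelField`, the per-row effect of the script above can be approximated with the standard library alone. This is only a sketch under assumptions: `load_examples` and `CSV_TEXT` are hypothetical names, the CSV content is inlined instead of read from `data.csv`, and whitespace splitting stands in for the default tokenization of `data.Field()`. It does not build vocabularies or batches the way torchtext does.

```python
import csv
import io

# Assumption: same content as data.csv, inlined so the sketch is self-contained.
CSV_TEXT = "name,age\nbob,47\nrichard,48\n"


def load_examples(handle):
    """Hypothetical helper: read (name, age) rows into dicts keyed 'n' and 'a'."""
    reader = csv.reader(handle)
    next(reader)  # skip the header row, like skip_header=True
    examples = []
    for name, age in reader:
        examples.append({
            "n": name.split(),  # whitespace tokenization of the text column
            "a": int(age),      # same conversion as str2int
        })
    return examples


examples = load_examples(io.StringIO(CSV_TEXT))
print(examples[0])
# {'n': ['bob'], 'a': 47}
```

The point of the comparison is that `preprocessing=str2int` is applied once per example at load time, which is why the label is already an `int` when the first example is printed.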