@itsuncheng
Created June 12, 2020 10:25
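import torch
from transformers import BertTokenizer
# Legacy torchtext API (torchtext <= 0.8); later releases moved it to torchtext.legacy.data and then removed it
from torchtext.data import Field, TabularDataset, BucketIterator, Iterator

# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')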
# Model parameters
MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)
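# Fields
# Labels must already be numeric, since no vocabulary is built for them
label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
# tokenizer.encode returns token ids directly, so no torchtext vocabulary is needed;
# fix_length pads or truncates every example to MAX_SEQ_LEN
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True,
                   fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)
# Field order must match the column order in the CSV files
fields = [('label', label_field), ('title', text_field), ('text', text_field), ('titletext', text_field)]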
# TabularDataset
# source_folder must be defined beforehand as the directory containing the three CSV files
train, valid, test = TabularDataset.splits(path=source_folder, train='train.csv', validation='valid.csv',
                                           test='test.csv', format='CSV', fields=fields, skip_header=True)
# Iterators
# Train and validation batches are sorted by tokenized text length; test batches keep file order
train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text),
                            device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text),
                            device=device, train=True, sort=True, sort_within_batch=True)
test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)
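
The effect of `fix_length=MAX_SEQ_LEN` with `pad_token=PAD_INDEX` on a single example can be sketched in plain Python. This is only an illustration of the pad-or-truncate behaviour, not torchtext's own code: `pad_or_truncate` is a hypothetical helper, and the token ids are illustrative (in `bert-base-uncased`, 101 is `[CLS]`, 102 is `[SEP]`, and 0 is `[PAD]`).

```python
from typing import List


def pad_or_truncate(token_ids: List[int], fix_length: int, pad_index: int) -> List[int]:
    """Hypothetical helper: clip to fix_length, then right-pad with pad_index."""
    clipped = token_ids[:fix_length]
    return clipped + [pad_index] * (fix_length - len(clipped))


# Illustrative ids only: [CLS] ... [SEP]
short_example = [101, 7592, 2088, 102]
long_example = [101] + list(range(2000, 2010)) + [102]

padded = pad_or_truncate(short_example, fix_length=8, pad_index=0)
clipped = pad_or_truncate(long_example, fix_length=8, pad_index=0)

print(padded)   # short input is right-padded with the pad index
print(clipped)  # long input is cut at fix_length, so the trailing [SEP] is lost
```

Because truncation cuts from the end, any example longer than `MAX_SEQ_LEN` tokens loses its closing `[SEP]` id. If that matters for the model, one option is to truncate inside the tokenizer call instead, but that is a design choice outside this snippet.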