Skip to content

Instantly share code, notes, and snippets.

@henry16lin
Created February 15, 2020 15:02
Show Gist options
  • Save henry16lin/53185c52e87205d42434c8041a58547d to your computer and use it in GitHub Desktop.
Save henry16lin/53185c52e87205d42434c8041a58547d to your computer and use it in GitHub Desktop.
dataloader
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
""""
create_mini_batch(samples)吃上面定義的mydataset
回傳訓練 BERT 時會需要的 4 個 tensors:
- tokens_tensors : (batch_size, max_seq_len_in_batch)
- segments_tensors: (batch_size, max_seq_len_in_batch)
- masks_tensors : (batch_size, max_seq_len_in_batch)
- label_ids : (batch_size)
"""
# collate_fn: tells the DataLoader how to merge several samples into one
# batch before it is handed to the model. Once the sequences have been padded
# to a common length, attention must be limited to the non-padded positions.
def create_mini_batch(samples):
    """Collate a list of dataset samples into the tensors BERT trains on.

    Args:
        samples: Non-empty sequence of per-example tuples. Index 0 is the
            token-id tensor, index 1 the segment-id tensor and index 2 the
            label tensor, or ``None`` when the split carries no labels.
            The token and segment tensors are expected to be 1-D and may
            differ in length between samples.

    Returns:
        A tuple ``(tokens_tensors, segments_tensors, masks_tensors,
        label_ids)``:

        - tokens_tensors: ``(batch_size, max_seq_len_in_batch)``
        - segments_tensors: ``(batch_size, max_seq_len_in_batch)``
        - masks_tensors: ``(batch_size, max_seq_len_in_batch)``, ``torch.long``
        - label_ids: ``(batch_size)``, or ``None`` for unlabelled batches
    """
    token_seqs = [s[0] for s in samples]
    segment_seqs = [s[1] for s in samples]

    # Only the first sample is inspected: a batch is treated as labelled
    # (training/validation) or unlabelled (test) as a whole.
    if samples[0][2] is not None:
        label_ids = torch.stack([s[2] for s in samples])
    else:
        label_ids = None

    # Zero-pad every sequence up to the longest one in this batch.
    tokens_tensors = pad_sequence(token_seqs, batch_first=True)
    segments_tensors = pad_sequence(segment_seqs, batch_first=True)

    # Attention mask: 1 wherever the token id is non-zero, 0 on the zero
    # padding, so that BERT only attends to real tokens. This relies on 0
    # being used exclusively as the padding id.
    masks_tensors = (tokens_tensors != 0).long()

    return tokens_tensors, segments_tensors, masks_tensors, label_ids
# Build DataLoaders that each yield BATCH_SIZE samples per step; `collate_fn`
# merges the list of samples into a single mini-batch.
BATCH_SIZE = 16
trainloader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=create_mini_batch,
)
valloader = DataLoader(
    valset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=create_mini_batch,
)
testloader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=create_mini_batch,
)

# Pull one training batch and show its four components, one per line.
data = next(iter(trainloader))
tokens_tensors, segments_tensors, masks_tensors, label_ids = data
print(*data, sep="\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment