Created
February 15, 2020 15:02
-
-
Save henry16lin/53185c52e87205d42434c8041a58547d to your computer and use it in GitHub Desktop.
dataloader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
"""" | |
create_mini_batch(samples) consumes samples from the mydataset defined above
and returns the 4 tensors needed to train BERT:
- tokens_tensors  : (batch_size, max_seq_len_in_batch)
- segments_tensors: (batch_size, max_seq_len_in_batch)
- masks_tensors   : (batch_size, max_seq_len_in_batch)
- label_ids       : (batch_size)
""" | |
def create_mini_batch(samples):
    """Collate a list of samples into one zero-padded mini-batch.

    Meant to be passed as ``collate_fn`` to a ``DataLoader``: it decides how
    several individual samples are merged into the batch fed to the model.

    Args:
        samples: Non-empty list of ``(tokens_tensor, segments_tensor,
            label_tensor)`` triples. ``label_tensor`` may be ``None`` for
            unlabeled data; only the first sample is inspected to decide
            whether the batch carries labels.

    Returns:
        A 4-tuple ``(tokens_tensors, segments_tensors, masks_tensors,
        label_ids)``:

        - ``tokens_tensors``: token ids, zero-padded to the longest sequence
          in the batch, shape ``(batch_size, max_seq_len_in_batch)``.
        - ``segments_tensors``: segment ids, padded the same way.
        - ``masks_tensors``: ``torch.long`` attention mask with 1 where
          ``tokens_tensors`` is non-zero and 0 on padding.
        - ``label_ids``: the labels stacked along a new first dimension, or
          ``None`` when the first sample has no label.
    """
    # Zero-pad every sequence up to the longest one in this batch.
    tokens_tensors = pad_sequence([s[0] for s in samples], batch_first=True)
    segments_tensors = pad_sequence([s[1] for s in samples], batch_first=True)

    # Only labeled data (e.g. the training set) carries labels.
    if samples[0][2] is not None:
        label_ids = torch.stack([s[2] for s in samples])
    else:
        label_ids = None

    # Attention mask: 1 on real tokens, 0 on padding, so the model attends
    # only to non-pad positions. Deriving it straight from the padded tokens
    # keeps it on the same device as the tokens and avoids allocating a
    # separate zeros tensor first.
    # NOTE: this treats token id 0 as padding (pad_sequence's default fill
    # value), so a genuine token with id 0 would be masked out as well.
    masks_tensors = (tokens_tensors != 0).long()

    return tokens_tensors, segments_tensors, masks_tensors, label_ids
# Build DataLoaders that each yield BATCH_SIZE samples per step; `collate_fn`
# merges the list of samples into a single mini-batch.
BATCH_SIZE = 16


def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the shared batch settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=create_mini_batch,
    )


# Only the training loader shuffles its samples.
trainloader = _build_loader(trainset, shuffle=True)
valloader = _build_loader(valset, shuffle=False)
testloader = _build_loader(testset, shuffle=False)

# Pull one batch from the training loader and print its four components,
# one per line.
data = next(iter(trainloader))
tokens_tensors, segments_tensors, masks_tensors, label_ids = data
print(tokens_tensors, segments_tensors, masks_tensors, label_ids, sep="\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment