Decrease Hugging Face Transformers training times by 2 - collator
# ...
from dataclasses import dataclass
from typing import Dict, List

import torch

# Features (the per-example container exposing input_ids, attention_mask and label)
# and DataCollator (the abstract collator class of the transformers release used here)
# are defined/imported in the elided part of the project.


def pad_seq(seq: List[int], max_batch_len: int, pad_value: int) -> List[int]:
    # in real life, use pad_sequence
    # https://pytorch.org/docs/master/generated/torch.nn.utils.rnn.pad_sequence.html
    return seq + (max_batch_len - len(seq)) * [pad_value]


@dataclass
class SmartCollator(DataCollator):
    pad_token_id: int

    def collate_batch(self, batch: List[Features]) -> Dict[str, torch.Tensor]:
        batch_inputs = list()
        batch_attention_masks = list()
        labels = list()
        # find the longest sequence of the mini batch
        max_size = max(len(ex.input_ids) for ex in batch)
        for item in batch:
            # apply padding at the mini batch level (dynamic padding)
            batch_inputs += [pad_seq(item.input_ids, max_size, self.pad_token_id)]
            batch_attention_masks += [pad_seq(item.attention_mask, max_size, 0)]
            labels.append(item.label)
        # expected Transformers input format (dict of Tensors)
        return {"input_ids": torch.tensor(batch_inputs, dtype=torch.long),
                "attention_mask": torch.tensor(batch_attention_masks, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
                }
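
For context, a minimal usage sketch (not part of the original gist) showing how such a collator can be handed to the Trainer, assuming a transformers release contemporary with this gist in which Trainer calls collate_batch on the provided DataCollator; the model checkpoint, output directory and train_dataset below are placeholder assumptions, and train_dataset is expected to yield Features items:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("camembert-base")  # placeholder checkpoint
model = AutoModelForSequenceClassification.from_pretrained("camembert-base")

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./output"),
    train_dataset=train_dataset,  # placeholder: any dataset returning Features items
    data_collator=SmartCollator(pad_token_id=tokenizer.pad_token_id),
)
trainer.train()

In recent transformers versions the data_collator argument is a plain callable taking a list of examples, so the same logic would live in a __call__ method rather than collate_batch.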