Last active
January 2, 2022 22:27
-
-
Save erip/81d2816f71ba2e95668095e5a1e1040e to your computer and use it in GitHub Desktop.
A PyTorch Sampler which samples batches containing no more than max_tokens post-pad tokens.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import random | |
from typing import List, Optional | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data import Sampler, DataLoader, Dataset | |
class DumbDataset(Dataset):
    """Minimal map-style dataset wrapping an indexable collection.

    Every lookup converts the stored item into a fresh ``torch.Tensor``.
    """

    def __init__(self, data):
        # Keep a reference to the raw items; conversion happens per lookup.
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, i):
        """Return item ``i`` converted to a tensor."""
        raw_item = self.data[i]
        return torch.tensor(raw_item)
class MaxTokenPerBatchSampler(Sampler):
    """Batch sampler that caps the padded size of every batch at ``max_tokens``.

    Samples are ordered from longest to shortest (optionally after jittering
    their lengths) and packed greedily so that
    ``longest_length_in_batch * batch_size <= max_tokens``. Every yielded item
    is a list of dataset indices, so an instance is meant to be passed to
    ``DataLoader`` as ``batch_sampler``.

    Args:
        lengths: Length of every sample, indexed like the dataset.
        max_tokens: Upper bound on ``longest_length * batch_size`` per batch.
        max_length: Samples longer than this are dropped; ``None`` disables
            the upper bound.
        min_length: Samples shorter than this are dropped.
        padding_ratio: Relative amount of uniform noise added to each length
            before sorting. Forced to ``0.0`` when ``shuffle`` is false.
        shuffle: If true, lengths are jittered before sorting and the indices
            inside each batch are shuffled on every iteration.

    Note:
        The noise is drawn once in ``__init__``, so which samples share a
        batch, and the order of the batches, is fixed for the lifetime of the
        sampler; only the order inside a batch changes between iterations.
        A single sample longer than ``max_tokens`` is emitted as a batch of
        its own (it cannot fit any budget, but is not silently dropped).
    """

    def __init__(self, lengths: List[int], max_tokens: int, max_length: Optional[int]=None, min_length: int = 1, padding_ratio: float = 0.1, shuffle: bool=True):
        if max_length is None:
            max_length = float('inf')
        assert 0 < min_length < max_length, "min_length must be positive and less than max_length"
        if not shuffle:
            padding_ratio = 0.0

        # Remember each surviving sample's position in the ORIGINAL list.
        # Sorting the filtered list directly would produce indices into the
        # filtered list, which no longer line up with the dataset.
        kept = [
            (idx, length)
            for idx, length in enumerate(lengths)
            if min_length <= length <= max_length
        ]
        keyed = [
            (self._generate_noise(length, padding_ratio), idx, length)
            for idx, length in kept
        ]
        # list.sort is stable (also with reverse=True), so samples with equal
        # keys keep their dataset order and the result is deterministic.
        keyed.sort(key=lambda entry: entry[0], reverse=True)

        # Dataset indices, longest (noisy) sample first, and their true lengths.
        self.sorted_length_idx = [idx for _, idx, _ in keyed]
        self.lengths = [length for _, _, length in keyed]
        self.max_length = max_length
        self.min_length = min_length
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    @staticmethod
    def _generate_noise(val: int, ratio: float):
        """Return ``val`` perturbed by uniform noise in ``[-val*ratio, val*ratio]``."""
        # Inspired by allennlp
        noise_value = val * ratio
        noise = random.uniform(-noise_value, noise_value)
        return val + noise

    def _pack(self):
        """Yield batches of dataset indices in sorted order, without shuffling."""
        curr_max_length = 0
        batch = []
        for idx, length in zip(self.sorted_length_idx, self.lengths):
            # Padded size if this sample joined the batch. Noise can put a
            # longer sample after a shorter one, so the running max is needed.
            max_batch_size = max(length, curr_max_length) * (len(batch) + 1)
            # Start a new batch on overflow. The ``batch`` guard prevents an
            # empty batch from being emitted when one sample alone is already
            # over budget.
            if batch and max_batch_size > self.max_tokens:
                yield batch
                curr_max_length = 0
                batch = []
            curr_max_length = max(curr_max_length, length)
            batch.append(idx)
        if batch:
            yield batch

    def __iter__(self):
        """Yield one list of dataset indices per batch."""
        for batch in self._pack():
            if self.shuffle:
                random.shuffle(batch)
            yield batch

    def __len__(self):
        """Return the number of batches produced by one pass."""
        return sum(1 for _ in self._pack())
def collate_fn(batch):
    """Pad a list of variable-length tensors into one batch-first tensor.

    Shorter sequences are filled on the right with ``-1``.
    """
    padded = pad_sequence(batch, padding_value=-1, batch_first=True)
    return padded
if __name__ == "__main__":
    # Demo: sample k is the value k repeated k times, for k = 1..19.
    samples = [[value] * value for value in range(1, 20)]
    ds = DumbDataset(samples)
    lengths = list(map(len, ds))
    batch_sampler = MaxTokenPerBatchSampler(lengths, max_tokens=4096, shuffle=False)
    loader = DataLoader(
        ds,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
    )
    # Print every padded batch together with its total element count.
    for i, batch in enumerate(loader):
        print(f"Batch {i}: {batch}: {batch.numel()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment