@erip · Last active January 2, 2022
A PyTorch Sampler which samples batches containing no more than max_tokens tokens after padding.
#!/usr/bin/env python3
import random
from typing import List, Optional

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Sampler, DataLoader, Dataset


class DumbDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        return torch.tensor(self.data[i])

    def __len__(self):
        return len(self.data)
class MaxTokenPerBatchSampler(Sampler):
    """Yields batches of dataset indices whose padded size (longest sequence
    in the batch times batch size) never exceeds ``max_tokens``."""

    def __init__(self, lengths: List[int], max_tokens: int, max_length: Optional[int] = None, min_length: int = 1, padding_ratio: float = 0.1, shuffle: bool = True):
        if max_length is None:
            max_length = float("inf")
        assert 0 < min_length < max_length, "min_length must be positive and less than max_length"
        # Without shuffling the order must be deterministic, so disable the noise.
        if not shuffle:
            padding_ratio = 0.0
        # Keep (dataset index, length) pairs so that filtering by length does
        # not invalidate the indices handed back to the DataLoader.
        kept = [(i, length) for i, length in enumerate(lengths) if min_length <= length <= max_length]
        noisy_lengths = [self._generate_noise(length, padding_ratio) for _, length in kept]
        order = torch.argsort(torch.tensor(noisy_lengths), descending=True).tolist()
        self.sorted_length_idx = [kept[i][0] for i in order]
        self.lengths = [kept[i][1] for i in order]
        self.max_length = max_length
        self.min_length = min_length
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    @staticmethod
    def _generate_noise(val: int, ratio: float) -> float:
        # Inspired by allennlp: jitter each length by up to ±ratio*val so that
        # sorting groups similar lengths together while batch composition
        # still varies between epochs.
        noise_value = val * ratio
        noise = random.uniform(-noise_value, noise_value)
        return val + noise
    def __iter__(self):
        curr_max_length = 0
        batch = []
        for idx, length in zip(self.sorted_length_idx, self.lengths):
            # Padded batch size is the longest sequence times the element
            # count; if this length exceeds the current max, every element
            # already in the batch needs extra padding.
            max_batch_size = max(length, curr_max_length) * (len(batch) + 1)
            # If adding this element would overflow the budget, emit the
            # current batch and start a new one. The `batch` check avoids
            # yielding an empty batch when one sequence alone exceeds max_tokens.
            if batch and max_batch_size > self.max_tokens:
                if self.shuffle:
                    random.shuffle(batch)
                yield batch
                curr_max_length = 0
                batch = []
            curr_max_length = max(curr_max_length, length)
            batch.append(idx)
        if batch:
            if self.shuffle:
                random.shuffle(batch)
            yield batch
def collate_fn(batch):
    # Pad every sequence in the batch to the longest one; -1 marks padding.
    return pad_sequence(batch, batch_first=True, padding_value=-1)


if __name__ == "__main__":
    ds = DumbDataset([[i] * i for i in range(1, 20)])
    lengths = [len(e) for e in ds]
    batch_sampler = MaxTokenPerBatchSampler(lengths, max_tokens=4096, shuffle=False)
    loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=collate_fn)
    for i, batch in enumerate(loader):
        print(f"Batch {i}: {batch}: {batch.numel()}")