Created
October 23, 2018 23:06
-
-
Save ssnl/205a4cd2e4e631a42cc9d8e879a296dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.data | |
class ChunkDataset(object):
    """Dataset whose samples are stored in fixed-length chunks (e.g. DB shards).

    It is indexed with a composite key ``(chunk_idx, indices_in_chunk)`` so
    that a whole batch is fetched from a single chunk in one call; the paired
    ``Sampler`` below produces exactly such keys.
    """

    def __init__(self, chunk_lengths):
        # Number of samples held by each chunk, e.g. (1, 3, 5, 10).
        self.chunk_lengths = tuple(chunk_lengths)

    def __getitem__(self, key):
        # `key` is one item yielded by `Sampler.__iter__`:
        # (chunk_idx tensor, 1-D tensor of sample indices inside that chunk).
        chunk_idx, indices_in_chunk = key
        return self.get_data_from_chunk(chunk_idx, indices_in_chunk)

    def get_data_from_chunk(self, chunk_idx, indices_in_chunk):
        # dummy impl
        # IRL, this can be querying a database
        return indices_in_chunk.to(torch.float).add_((chunk_idx.item() + 1) * 1000)

    class Sampler(torch.utils.data.Sampler):
        """Yields ``(chunk_idx, indices_in_chunk)`` batch keys.

        Chunks are drawn with probability proportional to how many samples
        they still have left this epoch; within a chunk, samples come from a
        fresh random permutation. Batches never cross chunk boundaries.
        """

        def __init__(self, dataset, batch_size):
            self.dataset = dataset
            self.batch_size = batch_size

        def __len__(self):
            # Batches per epoch: each chunk contributes ceil(len / batch_size)
            # batches since a batch is always cut from a single chunk.
            return sum(
                (n + self.batch_size - 1) // self.batch_size
                for n in self.dataset.chunk_lengths
            )

        def __iter__(self):
            chunk_indices = [torch.randperm(n) for n in self.dataset.chunk_lengths]
            chunk_remaining = torch.tensor(self.dataset.chunk_lengths)  # copies
            total_remaining = sum(chunk_remaining)
            while total_remaining > 0:
                # Weighted draw; exhausted chunks have weight 0 and are
                # therefore never selected by multinomial.
                chunk_idx = torch.multinomial(chunk_remaining.float(), 1)
                chunk_avail = chunk_remaining[chunk_idx].item()
                n_samples = min(self.batch_size, chunk_avail)
                # Consume the permutation from its tail end.
                yield chunk_idx, chunk_indices[chunk_idx][chunk_avail - n_samples : chunk_avail]
                chunk_remaining[chunk_idx] -= n_samples
                total_remaining -= n_samples

    def get_sampler(self, batch_size):
        """Return a ``Sampler`` over this dataset yielding batch-sized keys."""
        return ChunkDataset.Sampler(self, batch_size)

    @staticmethod
    def collate_fn(samples_collection):
        # we do our own batching, so `samples_collection` is a list of one element
        # return that element
        assert isinstance(samples_collection, (list, tuple)) and len(samples_collection) == 1
        return samples_collection[0]
def main():
    """Demo: iterate chunk-wise batches through a multi-worker DataLoader."""
    dataset = ChunkDataset(chunk_lengths=[1, 3, 5, 10])
    sampler = dataset.get_sampler(batch_size=3)
    # Each sampler item is already a full batch key, so the DataLoader's own
    # default batch_size=1 wraps it in a singleton list that collate_fn unwraps.
    loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, num_workers=2, collate_fn=dataset.collate_fn
    )
    for d in loader:
        print(d)


if __name__ == "__main__":
    # Guard is required with num_workers > 0: spawn-based platforms
    # (Windows/macOS) re-import this module in each worker process, and an
    # unguarded DataLoader at module level would recurse into worker creation.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment