Skip to content

Instantly share code, notes, and snippets.

@ssnl
Created October 23, 2018 23:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ssnl/205a4cd2e4e631a42cc9d8e879a296dc to your computer and use it in GitHub Desktop.
Save ssnl/205a4cd2e4e631a42cc9d8e879a296dc to your computer and use it in GitHub Desktop.
import torch
import torch.utils.data
class ChunkDataset(object):
    """Dataset whose samples are stored in fixed-length chunks.

    Batches are drawn from one chunk at a time via the nested `Sampler`,
    which yields ``(chunk_idx, indices_in_chunk)`` keys that `__getitem__`
    resolves into data.
    """

    def __init__(self, chunk_lengths):
        # chunk_lengths[i] is the number of samples held by chunk i.
        self.chunk_lengths = tuple(chunk_lengths)

    def __getitem__(self, key):
        # `key` is a (chunk_idx, indices_in_chunk) pair as produced by `Sampler`.
        chunk_idx, indices_in_chunk = key
        return self.get_data_from_chunk(chunk_idx, indices_in_chunk)

    def get_data_from_chunk(self, chunk_idx, indices_in_chunk):
        # dummy impl
        # IRL, this can be querying a database
        # FIX: use out-of-place `+` rather than `.to(torch.float).add_(...)`:
        # `.to()` returns the *same* tensor when the dtype already matches, so
        # the in-place add could mutate the caller's tensor. The result is
        # identical for the integer index tensors the sampler produces.
        return indices_in_chunk.to(torch.float) + (chunk_idx.item() + 1) * 1000

    class Sampler(torch.utils.data.Sampler):
        """Yields ``(chunk_idx, indices_in_chunk)`` batch keys.

        Chunks are picked with probability proportional to how many unseen
        samples they still hold; within each chunk a random permutation is
        consumed from the back, at most `batch_size` indices per batch.
        """

        def __init__(self, dataset, batch_size):
            self.dataset = dataset
            self.batch_size = batch_size

        def __iter__(self):
            # One independent shuffle per chunk.
            # (FIX: dropped the unused local `n_chunks`.)
            chunk_indices = [torch.randperm(n) for n in self.dataset.chunk_lengths]
            chunk_remaining = torch.tensor(self.dataset.chunk_lengths)  # copies
            total_remaining = int(chunk_remaining.sum().item())
            while total_remaining > 0:
                # Exhausted chunks have weight 0 and are never drawn again.
                chunk_idx = torch.multinomial(chunk_remaining.float(), 1)
                chunk_avail = chunk_remaining[chunk_idx].item()
                n_samples = min(self.batch_size, chunk_avail)
                # Take the next `n_samples` indices off the tail of this
                # chunk's permutation.
                yield chunk_idx, chunk_indices[chunk_idx][chunk_avail - n_samples : chunk_avail]
                chunk_remaining[chunk_idx] -= n_samples
                total_remaining -= n_samples

    def get_sampler(self, batch_size):
        """Return a `Sampler` over this dataset producing batches of at most
        `batch_size` samples."""
        return ChunkDataset.Sampler(self, batch_size)

    @staticmethod
    def collate_fn(samples_collection):
        # we do our own batching, so `samples_collection` is a list of one element
        # return that element
        assert isinstance(samples_collection, (list, tuple)) and len(samples_collection) == 1
        return samples_collection[0]
# Demo: feed a ChunkDataset through a DataLoader with the chunk-aware
# sampler; `collate_fn` unwraps the single-element list that DataLoader
# builds around each already-batched key.
dataset = ChunkDataset(chunk_lengths=[1, 3, 5, 10])
sampler = dataset.get_sampler(batch_size=3)
loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    num_workers=2,
    collate_fn=dataset.collate_fn,
)
for batch in loader:
    print(batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment