from __future__ import annotations

# Context: the dataset lives entirely in GPU memory.

from typing import Iterable

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.datasets import MNIST
class add_distractors(Dataset):
    def __init__(self, dataset: TensorDataset, n_distractors: int):
        """Create a dataset where sample `i` is the image at index `i` in `dataset`,
        along with `n_distractors` other images.
        """
        self.dataset = dataset
        self.n_distractors = n_distractors
        # Concatenate the dataset with itself so the last samples aren't wasted:
        # their distractors simply roll over to the start of the dataset.
        self._added_datasets = TensorDataset(
            *(torch.concat([tensor, tensor]) for tensor in dataset.tensors)
        )
    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        if not 0 <= index < len(self.dataset):
            raise IndexError(f"Index {index} is out of range.")
        # Slicing a TensorDataset returns a tuple of sliced tensors: entry [0] is
        # the stack of images from `index` to `index + n_distractors` (inclusive),
        # while the label comes from the un-sliced item at `index` alone.
        return (
            self._added_datasets[index : index + self.n_distractors + 1][0],
            self._added_datasets[index][1],
        )

    def __len__(self) -> int:
        return len(self.dataset)  # type: ignore
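
# A quick sketch of the indexing behavior above (a hypothetical toy example, not
# part of the original gist): with a 5-sample TensorDataset and 2 distractors,
# sample 4 wraps around and picks samples 0 and 1 as its distractors:
#
#   images = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 2, 2)
#   labels = torch.arange(5)
#   toy = add_distractors(TensorDataset(images, labels), n_distractors=2)
#   x, y = toy[4]
#   assert x.shape == (3, 2, 2)  # image 4, then images 0 and 1 (rolled over)
#   assert y == labels[4]        # the label is for image 4 only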


def batching(
    iterable: Iterable[tuple[Tensor, Tensor]],
    batch_size: int,
    drop_last: bool = False,
) -> Iterable[tuple[Tensor, Tensor]]:
    """Collects the elements from `iterable` into batches of size `batch_size`.

    If `drop_last` is set, the last batch is dropped when it contains fewer than
    `batch_size` elements.
    """
    current_batch_tensors = []
    for index, element in enumerate(iterable):
        current_batch_tensors.append(element)
        # If we have a full batch accumulated, yield it.
        if len(current_batch_tensors) == batch_size:
            xs, ys = zip(*current_batch_tensors)
            assert {x.shape for x in xs} == {xs[0].shape}, (
                index,
                [x.shape for x in xs],
            )
            yield torch.stack(xs, dim=0), torch.stack(ys, dim=0)
            current_batch_tensors.clear()
    # Any batch left over here is necessarily smaller than `batch_size`
    # (full batches were yielded and cleared above); yield it unless
    # `drop_last` is set.
    if current_batch_tensors and not drop_last:
        xs, ys = zip(*current_batch_tensors)
        yield torch.stack(xs, dim=0), torch.stack(ys, dim=0)
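
# A sketch of driving `batching` by hand (assumed usage, mirroring the
# commented-out line in `main()` below):
#
#   for x, y in batching(iter(dataset), batch_size=32, drop_last=True):
#       ...  # x: (32, 1000, 28, 28), y: (32,)
#
# Iterating the dataset directly keeps every tensor on the GPU; `batching` then
# only stacks them, with no collate function or worker processes involved.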


def main():
    import tqdm

    dataset = MNIST(root="data", train=True, download=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the raw MNIST tensors to the GPU once, up front.
    dataset = TensorDataset(dataset.data.to(device), dataset.targets.to(device))
    dataset = add_distractors(dataset, n_distractors=999)
    n_samples = len(dataset)
    assert n_samples == 60_000
    batch_size = 32
    # dataset = batching(iter(dataset), batch_size=batch_size)
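    # The tensors already live on `device`, so `num_workers` stays at 0: forked
    # DataLoader workers typically can't re-use CUDA tensors from the parent process.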
    dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    for i, (x, y) in enumerate(tqdm.tqdm(dataset, total=n_samples // batch_size)):
        assert x.shape == (batch_size, 1000, 28, 28), (i, x.shape)
        assert y.shape == (batch_size,), (i, y.shape)


if __name__ == "__main__":
    main()