@lebrice
Created June 14, 2022 21:23
from __future__ import annotations
# Context: the dataset lives entirely in GPU memory.
from typing import Iterable

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.datasets import MNIST

class add_distractors(Dataset):
    def __init__(self, dataset: TensorDataset, n_distractors: int):
        """Create a dataset where sample `i` is the image at index `i` in `dataset`,
        stacked with the `n_distractors` images that follow it.
        """
        self.dataset = dataset
        self.n_distractors = n_distractors
        # Concatenate the dataset with itself so the last samples aren't wasted:
        # their distractors simply roll over to the start of the dataset.
        self._added_datasets = TensorDataset(
            *(torch.concat([tensor, tensor]) for tensor in dataset.tensors)
        )

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        if index >= len(self.dataset):
            raise IndexError(f"Index {index} is out of range.")
        return (
            # The images at indices `index` through `index + n_distractors`.
            self._added_datasets[index : index + self.n_distractors + 1][0],
            # The label of the image at `index`.
            self._added_datasets[index][1],
        )

    def __len__(self) -> int:
        return len(self.dataset)  # type: ignore
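
# Usage sketch (illustrative, not part of the original gist): checks the shape
# contract of `add_distractors` on a tiny fake dataset. The names
# `_demo_add_distractors`, `fake_images` and `fake_labels` are assumptions
# introduced here for the example.
def _demo_add_distractors() -> None:
    fake_images = torch.randn(5, 28, 28)  # five fake 28x28 "images"
    fake_labels = torch.arange(5)
    demo_dataset = add_distractors(
        TensorDataset(fake_images, fake_labels), n_distractors=2
    )
    x, y = demo_dataset[0]
    # Sample 0 is image 0 stacked with the next two images; the label is
    # still the label of image 0.
    assert x.shape == (3, 28, 28)
    assert y.item() == fake_labels[0].item()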

def batching(
    iterable: Iterable[tuple[Tensor, Tensor]],
    batch_size: int,
    drop_last: bool = False,
) -> Iterable[tuple[Tensor, Tensor]]:
    """Collects the elements of `iterable` into batches of size `batch_size`.

    If `drop_last` is set, the last batch is dropped when it has fewer than
    `batch_size` elements.
    """
    current_batch_tensors = []
    for index, element in enumerate(iterable):
        current_batch_tensors.append(element)
        # If we have a full batch accumulated, yield it.
        if len(current_batch_tensors) == batch_size:
            xs, ys = zip(*current_batch_tensors)
            # All inputs in a batch must share a shape to be stacked.
            assert {x.shape for x in xs} == {xs[0].shape}, (
                index,
                [x.shape for x in xs],
            )
            yield torch.stack(xs, dim=0), torch.stack(ys, dim=0)
            current_batch_tensors.clear()
    # Any leftover partial batch is yielded only if `drop_last` is False.
    if current_batch_tensors and not drop_last:
        xs, ys = zip(*current_batch_tensors)
        yield torch.stack(xs, dim=0), torch.stack(ys, dim=0)
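
# Usage sketch (illustrative, not part of the original gist): batches a plain
# iterable of (x, y) pairs without a DataLoader. The helper name
# `_demo_batching` is an assumption introduced for the example.
def _demo_batching() -> None:
    pairs = [(torch.randn(28, 28), torch.tensor(i)) for i in range(10)]
    batches = list(batching(pairs, batch_size=4))
    # Ten samples with batch_size=4 give batches of 4, 4 and a trailing 2,
    # since `drop_last` defaults to False.
    assert [x.shape[0] for x, _ in batches] == [4, 4, 2]
    assert batches[0][0].shape == (4, 28, 28)
    assert batches[0][1].shape == (4,)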

def main():
    import tqdm

    dataset = MNIST(root="data", train=True, download=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the raw MNIST tensors to the GPU once, up front.
    dataset = TensorDataset(dataset.data.to(device), dataset.targets.to(device))
    dataset = add_distractors(dataset, n_distractors=999)
    n_samples = len(dataset)
    assert n_samples == 60_000
    batch_size = 32
    # dataset = batching(iter(dataset), batch_size=batch_size)
    dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    for i, (x, y) in enumerate(tqdm.tqdm(dataset, total=n_samples // batch_size)):
        assert x.shape == (batch_size, 1000, 28, 28), (i, x.shape)
        assert y.shape == (batch_size,), (i, y.shape)
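
# Hedged sketch of the commented-out alternative in `main`: driving the loop
# with `batching` over the raw dataset instead of a DataLoader. With the
# tensors already on the GPU, this skips the DataLoader's per-sample collation.
# The name `main_with_batching` is an assumption, not part of the original gist.
def main_with_batching() -> None:
    import tqdm

    mnist = MNIST(root="data", train=True, download=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor_dataset = TensorDataset(mnist.data.to(device), mnist.targets.to(device))
    distractor_dataset = add_distractors(tensor_dataset, n_distractors=999)
    batch_size = 32
    # `iter` on the Dataset calls `__getitem__` with 0, 1, 2, ... until
    # `IndexError` is raised.
    batches = batching(iter(distractor_dataset), batch_size=batch_size)
    n_batches = len(distractor_dataset) // batch_size
    for i, (x, y) in enumerate(tqdm.tqdm(batches, total=n_batches)):
        assert x.shape == (batch_size, 1000, 28, 28), (i, x.shape)
        assert y.shape == (batch_size,), (i, y.shape)


if __name__ == "__main__":
    main()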