Skip to content

Instantly share code, notes, and snippets.

@MischaPanch
Created August 10, 2023 17:42
Show Gist options
  • Save MischaPanch/ff790dd72ebb241006e186f8d1a26e3d to your computer and use it in GitHub Desktop.
Accelerated torch dataset and dataloader
import numpy as np
import torch
from torch.utils.data import Dataset, Subset, TensorDataset
from typing import (
Callable,
Iterator,
Literal,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
def get_batch_boundaries(
    batch_size: int,
    len_data: int,
    last_batch: Literal["drop", "merge", "keep"] = "merge",
):
    """Get the boundaries of batches for a given batch size and data length.

    Consecutive entries ``(b[i], b[i + 1])`` of the returned array are the
    half-open index ranges of the individual batches.

    If ``batch_size >= len_data`` a single batch spanning all of the data is
    returned, irrespective of ``last_batch``.

    :param batch_size: the size of each batch, must be positive
    :param len_data: the length of the data
    :param last_batch: one of "drop", "merge", or "keep".
        - "drop": drop the last batch if it is smaller than batch_size
        - "merge": merge the last batch with the previous batch
        - "keep": keep the last batch as is, even if it is smaller than batch_size
    :return: a numpy array of batch boundaries
    :raises ValueError: if ``batch_size`` is not positive or ``last_batch`` is
        not one of the supported options
    """
    # Validate eagerly: a zero step would make np.arange fail with an opaque
    # ZeroDivisionError and a negative one would silently yield no boundaries.
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, but got {batch_size}"
        )
    # Checked up front so that a misspelled option is reported on every code
    # path, not only when the data length is not a multiple of batch_size.
    if last_batch not in ("drop", "merge", "keep"):
        raise ValueError(
            f"last_batch must be one of 'drop', 'merge', or 'keep', "
            f"but got {last_batch}"
        )

    if batch_size >= len_data:
        return np.array([0, len_data])

    # Boundaries of all full batches; the stop value of len_data + 1 makes sure
    # that len_data itself is included when it is a multiple of batch_size.
    batch_boundaries = np.arange(0, len_data + 1, batch_size)
    if len_data % batch_size == 0 or last_batch == "drop":
        return batch_boundaries
    if last_batch == "merge":
        # Stretch the final full batch so that it absorbs the remainder.
        batch_boundaries[-1] = len_data
        return batch_boundaries
    # "keep": the remainder forms an additional, smaller batch.
    return np.append(batch_boundaries, len_data)
class Accelerated2DTensorDataset(Dataset):
    """
    Drop-in alternative to torch.utils.data.TensorDataset for 1D and 2D tensors.

    All tensors are concatenated column-wise into one tensor at construction
    time. Retrieving an item indexes that single tensor and then cuts out the
    column range belonging to each original tensor, instead of indexing every
    tensor separately.

    NOTE(review): because everything is stored in one stacked tensor, inputs of
    differing dtypes presumably end up sharing a single dtype - confirm this is
    acceptable for the intended use.
    """

    def __init__(self, *tensors: torch.Tensor) -> None:
        """
        :param tensors: the tensors to combine. They must agree in their first
            dimension and have at most 2 dimensions; 1D tensors are stored as
            a single column.
        :raises ValueError: if the first dimensions differ or a tensor has more
            than 2 dimensions
        """
        self._len = tensors[0].shape[0]
        columns = []
        # offsets[i] is the first column of the i-th tensor in the stacked tensor
        offsets = [0]
        for current in tensors:
            num_rows = current.shape[0]
            if num_rows != self._len:
                raise ValueError(
                    "All tensors must have the same first dimension, "
                    f"but got {num_rows} and {self._len}"
                )
            # 1D tensors become a single column so everything can be hstacked
            as_2d = current.unsqueeze(1) if current.dim() == 1 else current
            if as_2d.dim() != 2:
                raise ValueError(
                    "All tensors must have up to 2 dimensions, "
                    f"but got {as_2d.dim()}"
                )
            columns.append(as_2d)
            offsets.append(offsets[-1] + as_2d.shape[1])

        self._stacked_tensors = torch.hstack(columns)
        self._slices = [
            slice(start, stop) for start, stop in zip(offsets, offsets[1:])
        ]

    def __getitem__(self, index) -> Tuple[torch.Tensor, ...]:
        """Return one entry per original tensor, each cut from the stacked tensor."""
        parts = [self._stacked_tensors[index, columns] for columns in self._slices]
        return tuple(parts)

    def __len__(self) -> int:
        """Return the size of the shared first dimension."""
        return self._len
# Union of the container types that BatchDataLoader accepts for its ``data``
# argument.
SupportsBatching = Union[
TensorDataset,
Accelerated2DTensorDataset,
Subset,
torch.Tensor,
np.ndarray,
Sequence,
]
# Generic type variable used for the return type of a user-supplied
# ``collate_fn`` in BatchDataLoader.
T = TypeVar("T")
class BatchDataLoader:
    """A minimal loader that yields contiguous slices of in-memory data as batches."""

    def __init__(
        self,
        data: SupportsBatching,
        batch_size: int,
        shuffle: bool = False,
        last_batch: Literal["drop", "merge", "keep"] = "merge",
        collate_fn: Optional[Callable[[Union[torch.Tensor, np.ndarray]], T]] = None,
    ) -> None:
        """A simple data loader that returns batches of data.

        :param data: the data to be loaded. Plain sequences are converted to
            numpy arrays. A ``Subset`` is resolved to the entries selected by
            its indices. Batches are obtained by slicing the stored data, so
            their type is whatever slicing the data returns.
        :param batch_size: the size of each batch
        :param shuffle: whether to shuffle the data before batching
        :param last_batch: one of "drop", "merge", or "keep".
            - "drop": drop the last batch if it is smaller than batch_size
            - "merge": merge the last batch with the previous batch
            - "keep": keep the last batch as is, even if it is smaller than batch_size
        :param collate_fn: a function to apply to each batch before returning it
        """
        if isinstance(data, Sequence):
            data = np.array(data)
        if isinstance(data, Subset):
            # Simply unwrapping to ``data.dataset`` would silently iterate over
            # the whole underlying dataset; the subset's indices must be applied.
            data = self._resolve_subset(data)
        self._data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.last_batch = last_batch
        self._boundary_idxs = get_batch_boundaries(
            batch_size, len(data), last_batch=last_batch
        )
        self._num_batches = len(self._boundary_idxs) - 1
        self.collate_fn = collate_fn or (lambda x: x)

    @staticmethod
    def _resolve_subset(subset: Subset):
        """Return the entries of the underlying dataset selected by ``subset``.

        Nested subsets are supported by composing their indices. The base
        dataset has to support indexing with an integer array.
        """
        indices = None
        data = subset
        while isinstance(data, Subset):
            own_indices = np.asarray(data.indices, dtype=np.int64)
            # The indices collected so far refer to positions in this subset,
            # so they select from this subset's own indices.
            indices = own_indices if indices is None else own_indices[indices]
            data = data.dataset
        if isinstance(data, Sequence):
            data = np.array(data)
        selected = data[indices]
        if isinstance(data, (TensorDataset, Accelerated2DTensorDataset)):
            # indexing these types returns a tuple of tensors, so wrap it again
            selected = type(data)(*selected)
        return selected

    # TODO: the generic annotation here is probably incorrect
    def __iter__(self) -> Iterator[Union[np.ndarray, torch.Tensor, T]]:
        """Yield the batches in order, reshuffling the data first if requested."""
        if self.shuffle:
            self._shuffle_data()
        for lower, upper in zip(self._boundary_idxs[:-1], self._boundary_idxs[1:]):
            yield self.collate_fn(self._data[lower:upper])

    def _shuffle_data(self):
        """Replace the stored data by a randomly permuted copy of itself."""
        original = self._data
        shuffled = original[np.random.permutation(len(original))]
        if isinstance(original, (TensorDataset, Accelerated2DTensorDataset)):
            # retrieving data from these types changes the type, so we change it back
            shuffled = type(original)(*shuffled)
        self._data = shuffled

    def __len__(self) -> int:
        """Return the number of batches produced per iteration."""
        return self._num_batches
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment