Skip to content

Instantly share code, notes, and snippets.

@danesherbs
Last active September 7, 2022 02:15
Show Gist options
  • Save danesherbs/138ac40e66e8633ab9e23320ff0222de to your computer and use it in GitHub Desktop.
Save danesherbs/138ac40e66e8633ab9e23320ff0222de to your computer and use it in GitHub Desktop.
A class which makes a PyTorch dataset from a dictionary of tensors
class DictDataset(torch.utils.data.Dataset):
    """Map-style dataset built from a dictionary of tensors.

    Example ``i`` is ``{key: tensor[i] for key, tensor in inputs.items()}``,
    i.e. every tensor is indexed along its first dimension.

    Args:
        inputs: Mapping from name to tensor. It must be non-empty, and every
            tensor must have at least one dimension, with the same size in
            dimension 0 across all tensors.

    Raises:
        ValueError: If ``inputs`` is empty, a tensor is 0-dimensional, or the
            tensors disagree on the size of their first dimension.
    """

    def __init__(self, inputs: Dict[str, Tensor]):
        # Explicit raises instead of ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would silently skip validation.
        if not inputs:
            raise ValueError("inputs must be non-empty")

        first_dim_sizes = {}
        for name, tensor in inputs.items():
            if len(tensor.shape) == 0:
                raise ValueError(
                    f"tensor {name!r} is 0-dimensional; every tensor needs a "
                    "first dimension to index examples along"
                )
            first_dim_sizes[name] = tensor.shape[0]

        if len(set(first_dim_sizes.values())) != 1:
            raise ValueError(
                "all tensors must have same shape in first dimension, "
                f"got {first_dim_sizes}"
            )

        self._length = next(iter(first_dim_sizes.values()))
        # Shallow copy: keys the caller later adds to (or removes from) their
        # dict must not leak into this dataset unvalidated.
        self._inputs = dict(inputs)

    def __len__(self) -> int:
        """Return the number of examples (shared first-dimension size)."""
        return self._length

    def __getitem__(self, idx) -> Dict[str, Tensor]:
        """Return a dict with every tensor indexed by ``idx`` along dimension 0."""
        return {k: v[idx] for k, v in self._inputs.items()}
def dictdataset_collate_fn(batch: Sequence[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    """Collate a sequence of ``DictDataset`` examples into one batched dict.

    For each key, the per-example tensors are stacked along a new leading
    dimension, so per-example shape ``S`` becomes ``(len(batch), *S)``.

    Args:
        batch: Non-empty sequence of examples. Every example is expected to
            contain the keys of ``batch[0]``, with matching tensor shapes per key.

    Returns:
        Dict mapping each key of ``batch[0]`` to its stacked batch tensor.

    Raises:
        ValueError: If ``batch`` is empty.
        KeyError: If a later example lacks a key present in ``batch[0]``.
    """
    if not batch:
        raise ValueError("batch must be non-empty")
    # torch.stack rather than torch.vstack: vstack turns 0-d examples into
    # shape (B, 1) and concatenates rank>=2 examples along dim 0, giving
    # (B*H, ...) instead of (B, H, ...). stack yields (B, *S) for every rank
    # and is identical to vstack for 1-D examples.
    return {key: torch.stack([example[key] for example in batch]) for key in batch[0]}
# Example usage: 10 examples, each holding a length-20 row per key.
# NOTE(review): attention_mask uses torch.ones' default floating dtype while
# input_ids is an integer tensor from torch.randint; confirm the consumer
# accepts a float mask.
inputs = dict(
    input_ids=torch.randint(0, 100, size=(10, 20)),
    attention_mask=torch.ones((10, 20)),
)
dataset = DictDataset(inputs)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=4, collate_fn=dictdataset_collate_fn
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment