Skip to content

Instantly share code, notes, and snippets.

@keuv-grvl
Created April 12, 2023 12:33
Show Gist options
  • Save keuv-grvl/6efe35e769be80020d60ea83a034e491 to your computer and use it in GitHub Desktop.
Save keuv-grvl/6efe35e769be80020d60ea83a034e491 to your computer and use it in GitHub Desktop.
Wrap a Hugging Face (HF) `datasets.Dataset` in a `torch.utils.data.Dataset`
import torch
import datasets
class HFDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset that delegates to a wrapped dataset object.

    Indexing and ``len()`` are forwarded unchanged to the wrapped object, so
    whatever it returns for a given key (including any transform attached to
    it) is exactly what this wrapper returns.
    """

    def __init__(self, dset: "datasets.Dataset"):
        """Store the dataset to wrap.

        Args:
            dset: The dataset whose ``__getitem__`` and ``__len__`` are reused.
        """
        self.dset = dset

    def __getitem__(self, idx):
        """Return ``self.dset[idx]``.

        ``idx`` is passed through untouched, so any key the wrapped dataset
        supports (e.g. an int or a slice) is accepted here as well.
        """
        return self.dset[idx]

    def __len__(self):
        """Return the length reported by the wrapped dataset."""
        return len(self.dset)
if __name__ == "__main__":
    # Usage sketch: every `...` below is a placeholder to be replaced with
    # real arguments before running.

    # load a dataset from HF hub
    hf_ds = datasets.load_dataset("the-dataset")

    # process your data
    def trsfm_fn(example) -> dict:
        return {...: ...}

    hf_ds = hf_ds.map(...).sort(...).filter(...).remove_columns(...).with_transform(trsfm_fn)

    # wrap as a pytorch Dataset
    train_ds = HFDataset(hf_ds["train"])
    train_ds[123:132]  # the transform function is called once, good

    # build dataloader
    train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=3, drop_last=True)
    batch_iter = iter(train_dl)
    # next() needs the iterator, not the DataLoader itself (which is only iterable)
    example_batch = next(batch_iter)
    # NOTE: when going through the DataLoader, .with_transform(...) is applied
    # to each row individually rather than once per batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment