@prateekjoshi565
Last active July 17, 2020 17:09
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Assumes train_seq, train_mask, train_y, val_seq, val_mask and val_y are tensors
# created earlier (presumably token IDs, attention masks and labels).

# define the batch size
batch_size = 32

# wrap the training tensors in a dataset (they must share the same first dimension)
train_data = TensorDataset(train_seq, train_mask, train_y)
# sample the training data in random order (a new permutation every epoch)
train_sampler = RandomSampler(train_data)
# DataLoader for the training set
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# wrap the validation tensors in a dataset
val_data = TensorDataset(val_seq, val_mask, val_y)
# sample the validation data in order (no shuffling is needed for evaluation)
val_sampler = SequentialSampler(val_data)
# DataLoader for the validation set
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)
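To illustrate what the two samplers contribute, here is a minimal pure-Python sketch (not PyTorch's implementation) of the index order each one yields and of how a DataLoader with `batch_size=32` groups those indices into batches. It assumes the default settings: `RandomSampler` without replacement, and `drop_last=False`, so the final batch may be smaller than `batch_size`. The helper names and the dataset size of 70 are hypothetical, chosen only for illustration.

```python
import random
from typing import Iterator, List, Optional, Sequence


def sequential_indices(n: int) -> List[int]:
    """Index order a SequentialSampler yields: 0, 1, ..., n-1."""
    return list(range(n))


def random_indices(n: int, seed: Optional[int] = None) -> List[int]:
    """Index order a default RandomSampler yields: a random permutation of 0..n-1
    (every index exactly once, no replacement)."""
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    return idx


def batches(indices: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Group indices into consecutive chunks of batch_size. The last chunk may be
    smaller, mirroring DataLoader's default drop_last=False."""
    for start in range(0, len(indices), batch_size):
        yield list(indices[start:start + batch_size])


if __name__ == "__main__":
    n, batch_size = 70, 32  # hypothetical dataset size; batch size as in the gist
    # validation-style loader: same order every time
    print([len(b) for b in batches(sequential_indices(n), batch_size)])  # [32, 32, 6]
    # training-style loader: shuffled order, same batch sizes
    print([len(b) for b in batches(random_indices(n, seed=0), batch_size)])  # [32, 32, 6]
```

Either way each sample is visited exactly once per epoch; only the order differs, which is why random sampling is used for training and sequential sampling for validation.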