@ashunigion
Created June 12, 2019 00:27
Creation of train, validation and test dataloaders
```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# train_x/train_y, valid_x/valid_y and test_x/test_y are assumed to be
# NumPy arrays of features and labels prepared earlier (not shown here).

# create Tensor datasets
train_data = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))
valid_data = TensorDataset(torch.from_numpy(valid_x), torch.from_numpy(valid_y))
test_data = TensorDataset(torch.from_numpy(test_x), torch.from_numpy(test_y))

# dataloaders
batch_size = 50

# make sure to SHUFFLE the training data so batches arrive in a new order each epoch;
# evaluation metrics do not depend on order, so validation/test stay unshuffled
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
valid_loader = DataLoader(valid_data, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
```
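The snippet relies on train_x, valid_x, test_x and their labels already existing. The following self-contained sketch fills that gap under loudly stated assumptions: the data is synthetic (random integer sequences of length 200 with binary labels, standing in for the real dataset), and the 80/10/10 split is a hypothetical choice, not necessarily the author's. It builds the three datasets and loaders the same way and pulls one batch to check shapes.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical synthetic data (assumption): 1000 samples, each a sequence of
# 200 integer token ids, with binary labels.
rng = np.random.default_rng(seed=0)
features = rng.integers(0, 1000, size=(1000, 200), dtype=np.int64)
labels = rng.integers(0, 2, size=1000, dtype=np.int64)

# Hypothetical 80/10/10 split: first 80% for training,
# the remaining 20% halved into validation and test.
split_frac = 0.8
split_idx = int(len(features) * split_frac)
train_x, remaining_x = features[:split_idx], features[split_idx:]
train_y, remaining_y = labels[:split_idx], labels[split_idx:]

half = len(remaining_x) // 2
valid_x, test_x = remaining_x[:half], remaining_x[half:]
valid_y, test_y = remaining_y[:half], remaining_y[half:]

# Wrap the arrays as tensors; torch.from_numpy shares memory with the NumPy array.
train_data = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))
valid_data = TensorDataset(torch.from_numpy(valid_x), torch.from_numpy(valid_y))
test_data = TensorDataset(torch.from_numpy(test_x), torch.from_numpy(test_y))

batch_size = 50
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
valid_loader = DataLoader(valid_data, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)

# Pull one training batch to check the shapes a model would receive.
sample_x, sample_y = next(iter(train_loader))
print(sample_x.shape)  # torch.Size([50, 200])
print(sample_y.shape)  # torch.Size([50])
print(len(train_loader), len(valid_loader), len(test_loader))  # 16 2 2
```

With 800 training samples and batch_size = 50 the training loader yields 800 / 50 = 16 batches per epoch. If a split size is not a multiple of the batch size, the last batch is smaller; passing drop_last=True to DataLoader discards it, which can matter for models that expect a fixed batch size.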