
@adamoudad
Created March 20, 2021 21:14
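
import torch
from torch.utils.data import TensorDataset, DataLoader

# Wrap the NumPy arrays (x_train, y_train, x_val, y_val, defined elsewhere) as tensor datasets
train_data = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
valid_data = TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
# drop_last=True discards the final incomplete batch, so every batch has exactly batch_size samples
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)
valid_loader = DataLoader(valid_data, shuffle=True, batch_size=batch_size, drop_last=True)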
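The snippet assumes that x_train, y_train, x_val, y_val and batch_size already exist. The following is a minimal, self-contained sketch that fills them with hypothetical random data (100 training and 30 validation sequences of 20 token ids, binary labels, batch_size = 16) so the loaders can be built and inspected. The data shapes and values are assumptions for illustration only, not the author's dataset.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical synthetic data standing in for x_train/y_train/x_val/y_val:
# 100 training and 30 validation sequences of 20 token ids, with 0/1 labels.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 1000, size=(100, 20))
y_train = rng.integers(0, 2, size=100)
x_val = rng.integers(0, 1000, size=(30, 20))
y_val = rng.integers(0, 2, size=30)
batch_size = 16  # hypothetical value

train_data = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
valid_data = TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)
valid_loader = DataLoader(valid_data, shuffle=True, batch_size=batch_size, drop_last=True)

# With drop_last=True the incomplete final batch is discarded:
# 100 // 16 = 6 training batches, 30 // 16 = 1 validation batch.
print(len(train_loader), len(valid_loader))  # 6 1

for inputs, labels in train_loader:
    # Every batch has exactly batch_size rows.
    assert inputs.shape == (batch_size, 20)
    assert labels.shape == (batch_size,)
```

In this sketch, drop_last=True means 4 training samples and 14 validation samples are skipped in each pass. That trade-off keeps every batch the same shape, which matters when a model allocates state (for example an RNN hidden state) for a fixed batch size.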