

@akkefa
Created December 20, 2022 07:04
PyTorch DataLoader Example
import torch
from torch.utils.data import DataLoader, TensorDataset
# Define the dataset
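X = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32)  # 4 samples, 2 features each
# Binary labels as float32 (use dtype=torch.long instead for nn.CrossEntropyLoss)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
dataset = TensorDataset(X, y)  # dataset[i] returns the pair (X[i], y[i])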
# Create the dataloader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Iterate through the dataloader and process the data
for i, (inputs, labels) in enumerate(dataloader):
    print(f'Batch {i}:')
    print(f' Inputs: {inputs}')
    print(f' Labels: {labels}')
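Conceptually, with `batch_size=2, shuffle=True`, the loop above behaves roughly like this: draw a random permutation of the sample indices once per epoch, slice it into chunks of two, and stack the selected samples into batch tensors. The sketch below reproduces that by hand so the two can be compared. It is a simplification, not the library's implementation: the real `DataLoader` goes through a sampler, a batch sampler and `default_collate`, and also supports worker processes, `drop_last`, and similar options. The helper name `manual_batches` is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
dataset = TensorDataset(X, y)


def manual_batches(dataset, batch_size, generator=None):
    """Hypothetical helper: approximates DataLoader(shuffle=True) batching for a map-style dataset."""
    # Draw a random order of sample indices (what shuffle=True does once per epoch)
    order = torch.randperm(len(dataset), generator=generator).tolist()
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        samples = [dataset[j] for j in idx]  # each sample is an (input, label) pair
        # Collate: stack the chunk's inputs and labels into batch tensors
        inputs = torch.stack([s[0] for s in samples])
        labels = torch.stack([s[1] for s in samples])
        yield inputs, labels


# Compare with the real DataLoader: both should visit every sample exactly once per epoch
loader = DataLoader(dataset, batch_size=2, shuffle=True)
real = list(loader)
manual = list(manual_batches(dataset, batch_size=2))

for name, batches in [("DataLoader", real), ("manual", manual)]:
    seen = torch.cat([inputs for inputs, _ in batches])
    print(name, [tuple(b[0].shape) for b in batches], sorted(seen[:, 0].tolist()))
```

Because the permutation is random, the two loops will usually yield batches in different orders. Only the set of samples seen per epoch, and the pairing of each input with its label, is guaranteed to match. To make the real `DataLoader` order reproducible, you can pass a seeded generator, e.g. `DataLoader(dataset, batch_size=2, shuffle=True, generator=torch.Generator().manual_seed(0))`.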