Skip to content

Instantly share code, notes, and snippets.

@pedrohbtp
Last active December 11, 2018 13:40
Show Gist options
  • Save pedrohbtp/4ac8b6e470d86c42d73875189de21a9e to your computer and use it in GitHub Desktop.
Save pedrohbtp/4ac8b6e470d86c42d73875189de21a9e to your computer and use it in GitHub Desktop.
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
class ExampleDataset(Dataset):
    """Example map-style dataset backed by a CSV file.

    The CSV is read once with pandas; each sample is one row of the
    resulting DataFrame.
    """

    def __init__(self, csv_file):
        """Read the CSV file into a DataFrame.

        Args:
            csv_file (string): Path to the csv file containing data.
        """
        self.data_frame = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return the row at integer position ``idx``.

        Args:
            idx (int): Zero-based row position; negative values count
                from the end.

        Returns:
            dict: Mapping of column name to that row's value.

        Raises:
            IndexError: If ``idx`` is out of bounds.
        """
        # ``data_frame[idx]`` would look up a *column* labelled ``idx``;
        # ``.iloc`` selects the row by position, which is what a sample
        # index means here.
        row = self.data_frame.iloc[idx]
        # A pandas Series is not a type the DataLoader's default collate
        # function accepts; a plain dict of scalars is.
        return row.to_dict()
# Instantiate the dataset from the CSV file.
example_dataset = ExampleDataset('my_data_file.csv')
# batch_size: number of samples returned per iteration
# shuffle: reshuffle the data every epoch so it is not always read in the same order
# num_workers: number of worker subprocesses used to load the data in parallel
# NOTE: with num_workers > 0 on platforms that start workers via "spawn"
# (e.g. Windows), PyTorch requires this script-level code to be placed under
# an ``if __name__ == '__main__':`` guard.
example_data_loader = DataLoader(example_dataset, batch_size=4, shuffle=True, num_workers=4)
# Loop over the data 4 samples at a time (the last batch may be smaller).
for batch_index, batch in enumerate(example_data_loader):
    print(batch_index, batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment