Skip to content

Instantly share code, notes, and snippets.

@ResidentMario
Created July 22, 2021 20:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ResidentMario/16525c985fa3aaf9943bcfae0d3e0022 to your computer and use it in GitHub Desktop.
Save ResidentMario/16525c985fa3aaf9943bcfae0d3e0022 to your computer and use it in GitHub Desktop.
# import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import pandas as pd
from PIL import Image
import time
import argparse
class TestDataset(Dataset):
    """Image dataset spanning two training folders and their label CSVs.

    Global indices ``0 .. len(labels[0]) - 1`` address the first folder;
    the remaining indices address the second folder. ``__getitem__`` returns
    only the transformed image -- no label is attached to the sample.
    """

    def __init__(self):
        """Read both label tables and build the image transform pipeline."""
        super().__init__()
        self.train_folders = [
            "/mnt/resized train 15/",
            "/mnt/resized train 19/",
        ]
        self.labels = [
            pd.read_csv("/mnt/labels/trainLabels15.csv"),
            pd.read_csv("/mnt/labels/trainLabels19.csv"),
        ]
        # First global index that belongs to the second folder.
        self.dir_breakpoint_idx = len(self.labels[0])
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((1024, 1024)),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomPerspective(),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the combined number of rows across all label tables."""
        return sum(len(labelset) for labelset in self.labels)

    def _filepath(self, i):
        """Translate global index ``i`` into the path of its ``.jpg`` file.

        The file stem is taken from the first column of the matching label
        row. An index past the end of the second label table raises
        ``IndexError`` (from ``iloc``) instead of silently wrapping around.
        """
        if i >= self.dir_breakpoint_idx:
            # Subtract the offset rather than taking a modulo: a modulo only
            # works while the second table is no longer than the first, and
            # otherwise maps later rows back onto earlier ones.
            row = i - self.dir_breakpoint_idx
            labelset = self.labels[1]
            train_folder = self.train_folders[1]
        else:
            row = i
            labelset = self.labels[0]
            train_folder = self.train_folders[0]
        filename = f"{labelset.iloc[row, 0]}.jpg"
        return f"{train_folder}{filename}"

    def __getitem__(self, i):
        """Load the image at global index ``i`` and return it transformed."""
        # The context manager releases the file handle as soon as the
        # transform has consumed the pixel data.
        with Image.open(self._filepath(i)) as img:
            return self.transform(img)
def get_dataset():
    """Build and return a fresh ``TestDataset`` instance."""
    dataset = TestDataset()
    return dataset
def get_dataloader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    ``batch_size`` is forwarded as the loader's batch size, and
    ``num_workers`` sets how many worker processes load data concurrently.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
if __name__ == "__main__":
    # Disk-read benchmark: iterate the dataloader forever, discarding every
    # batch, and pause between batches for the requested amount of time.
    parser = argparse.ArgumentParser()
    # No default here on purpose: an omitted --batch-size is passed through
    # as None, exactly as before.
    parser.add_argument("--batch-size", type=int, help="Number of images per batch")
    # Default 0 so an omitted flag no longer hands None to DataLoader.
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Number of multiprocessed workers to used for loading data",
    )
    # Default 0.0 so an omitted flag no longer makes time.sleep(None) raise.
    parser.add_argument(
        "--sleep",
        type=float,
        default=0.0,
        help="Amount of time to sleep (in seconds) between disk reads",
    )
    args = parser.parse_args()

    dataset = get_dataset()
    dataloader = get_dataloader(dataset, args.batch_size, args.num_workers)

    # Runs until the process is interrupted.
    while True:
        for _ in dataloader:
            # do nothing, just read from disk
            time.sleep(args.sleep)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment