# Sample ImageNet DataLoader
#
# Source: GitHub gist Lyken17/d4bdd0afd97283ecec1864ba348e0e9c by @Lyken17,
# created September 5, 2021 00:35.
import argparse
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.multiprocessing as mp
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Benchmark configuration: dataset root directory, number of DataLoader
# worker processes, and images per batch.
data = "/dataset/imagenet"
workers = 12
batch_size = 256

# Per-split directories under the dataset root, consumed by ImageFolder.
traindir, valdir = (os.path.join(data, split) for split in ('train', 'val'))
# Training split: every image is randomly resized-and-cropped to 224x224 and
# converted to a tensor. No normalization is applied.
train_dataset = datasets.ImageFolder(
    root=traindir,
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ]),
)

# shuffle=False and no custom sampler: batches are drawn in dataset index order.
# NOTE(review): a real training run would normally shuffle. Confirm that
# sequential order is intended for this loading-speed benchmark, since it may
# not reflect a shuffled access pattern.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)
# Validation split: a 224x224 center crop converted to a tensor. No resize
# and no normalization are applied.
# NOTE(review): there is no Resize before CenterCrop, so the crop is taken
# from each image at its stored resolution. Confirm this is intended.
# val_loader is built here but not iterated by the benchmark loop in this script.
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(
        root=valdir,
        transform=transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]),
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)
# Measure raw DataLoader throughput on the training split. The loop makes one
# pass over train_loader and, after every batch, prints the cumulative
# images/second since the loop started. The first report therefore also
# includes worker start-up and the time to produce the first batch.
total_batches = 0  # NOTE: despite the name, this accumulates images, not batches
# perf_counter is monotonic and high-resolution. time.time() can jump when the
# wall clock is adjusted, which would corrupt the measured intervals.
start = time.perf_counter()
# Loop variables are named `images`/`_` so that the module-level `data`
# (the dataset root path) is not overwritten by the last batch tensor.
for i, (images, _) in enumerate(train_loader):
    passed_time = time.perf_counter() - start
    # Use the leading (batch) dimension of the actual tensor rather than
    # assuming every batch holds exactly `batch_size` images.
    total_batches += images.shape[0]
    print(f"[{i}] Loading speed: {(total_batches / passed_time):.2f} imgs/s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment