Train 2-3x faster on MNIST, with much less CPU usage, by making a few simple changes to the MNIST dataset class that PyTorch provides.

The PyTorch MNIST dataset is SLOW by default because it conforms to the usual torchvision interface: every sample is converted to a PIL image and then run through the transforms each time it is accessed. This is unnecessary if you just want a normalized MNIST and are not interested in image transforms (such as rotation or cropping). By folding the normalization into the dataset initialization, so that it runs once instead of on every access, you can save your CPU and speed up training by 2-3x.
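As a worked example of what "folding the normalization into initialization" computes, each raw pixel value $p \in \{0, \dots, 255\}$ is mapped once to a normalized value $x$. The mapping uses the commonly used MNIST statistics $\mu = 0.1307$ and $\sigma = 0.3081$. The values for a black and a white pixel show the resulting range:

```latex
x = \frac{p/255 - \mu}{\sigma}, \qquad \mu = 0.1307,\ \sigma = 0.3081

x(0) = \frac{0 - 0.1307}{0.3081} \approx -0.4242, \qquad
x(255) = \frac{1 - 0.1307}{0.3081} \approx 2.8215
```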

The bottleneck when training on MNIST with a GPU and a small-ish model is the CPU. In fact, even with six dataloader workers on a six-core i7, GPU utilization is only ~5-10%. Using FastMNIST increases GPU utilization to ~20-25% and reduces CPU utilization to near zero. On my particular model, the number of training steps per second at batch size 64 went from ~150 to ~500.
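To reproduce a steps-per-second comparison like the one above, you can time iteration over any loader. The sketch below is one way to do it under stated assumptions. `steps_per_second` is a hypothetical helper, not part of PyTorch. It skips a few warmup batches so that one-off costs such as worker start-up do not skew the number. The optional `step_fn` stands in for your real training step. Because CUDA kernels run asynchronously, a GPU step should synchronize (for example with `torch.cuda.synchronize()`) for the timing to be meaningful.

```python
import time
from typing import Any, Callable, Iterable, Optional

_SENTINEL = object()  # marks an exhausted iterator without confusing it with a real batch


def steps_per_second(batches: Iterable[Any], max_steps: int = 1000, warmup: int = 10,
                     step_fn: Optional[Callable[[Any], None]] = None) -> float:
    """Hypothetical benchmark helper: average number of batches processed per second.

    The first `warmup` batches are processed but not timed, so start-up costs
    (worker processes, CUDA context creation) do not distort the result.
    """
    iterator = iter(batches)
    for _ in range(warmup):
        batch = next(iterator, _SENTINEL)
        if batch is _SENTINEL:
            raise ValueError("not enough batches for warmup")
        if step_fn is not None:
            step_fn(batch)

    steps = 0
    start = time.perf_counter()
    for batch in iterator:
        if step_fn is not None:
            # For GPU work, step_fn should synchronize before returning,
            # otherwise queued kernels are not included in the measurement.
            step_fn(batch)
        steps += 1
        if steps >= max_steps:
            break
    elapsed = time.perf_counter() - start
    if steps == 0:
        raise ValueError("no batches left to time after warmup")
    return steps / max(elapsed, 1e-9)  # guard against a zero-length interval
```

With a real dataloader you would call something like `steps_per_second(train_dataloader, max_steps=500, step_fn=train_step)`, where `train_step` is your own (hypothetical here) function that runs the forward and backward pass.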

Instead of the default MNIST dataset, use this:

import torch
from torchvision.datasets import MNIST

device = torch.device('cuda')

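class FastMNIST(MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Scale data to [0,1] and add a channel dimension: [N, 28, 28] -> [N, 1, 28, 28]
        self.data = self.data.unsqueeze(1).float().div(255)

        # Normalize it with the usual MNIST mean and std
        self.data = self.data.sub_(0.1307).div_(0.3081)

        # Put both data and targets on GPU in advance
        self.data, self.targets = self.data.to(device), self.targets.to(device)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        return img, target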
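You can try the preprocessing steps from `__init__` without downloading MNIST or owning a GPU. The sketch below applies them to a synthetic uint8 tensor that stands in for `self.data`, then checks the result against a per-sample computation of the same formula. `preload_normalized` is a hypothetical helper name, not part of torchvision. It defaults to the CPU so that it runs anywhere.

```python
import torch


def preload_normalized(images: torch.Tensor, mean: float = 0.1307, std: float = 0.3081,
                       device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """Hypothetical helper: turn uint8 images [N, H, W] into normalized floats
    [N, 1, H, W] on `device`, doing all of the work once up front."""
    # Scale to [0, 1] and add the channel dimension that ToTensor() would add.
    data = images.unsqueeze(1).float().div(255)
    # Normalize with the dataset mean and std (in place, to avoid extra copies).
    data = data.sub_(mean).div_(std)
    # Move to the target device once, instead of once per batch.
    return data.to(device)


# Synthetic stand-in for MNIST: 8 random 28x28 uint8 "images".
images = torch.randint(0, 256, (8, 28, 28), dtype=torch.uint8)
data = preload_normalized(images)

# Reference: the same formula applied one sample at a time, as a per-access transform would.
reference = torch.stack([((img.float() / 255) - 0.1307) / 0.3081 for img in images]).unsqueeze(1)
```

To use a GPU, pass `device=torch.device("cuda")`. After that, every `__getitem__` is a plain tensor index with no conversion work.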

Then create the dataloaders like this:

from torch.utils.data import DataLoader

train_dataset = FastMNIST('data/MNIST', train=True, download=True)
test_dataset = FastMNIST('data/MNIST', train=False, download=True)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=10000, shuffle=False, num_workers=0)
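Even with the data already on the GPU, the DataLoader above still calls `__getitem__` once per sample and then stacks the results. A further option, which is not part of the original approach, is to skip DataLoader entirely. You can instead slice minibatches straight out of the preloaded tensors using one random permutation per epoch. The sketch below assumes `data` and `targets` are tensors on the same device. `iterate_minibatches` is a hypothetical helper. Like DataLoader with its default `drop_last=False`, it keeps the final partial batch.

```python
import torch


def iterate_minibatches(data: torch.Tensor, targets: torch.Tensor, batch_size: int,
                        shuffle: bool = True):
    """Hypothetical helper: yield (data, target) minibatches by indexing
    preloaded tensors with slices of a single per-epoch permutation."""
    n = data.shape[0]
    if shuffle:
        # Build the permutation on the data's device so indexing stays on that device.
        order = torch.randperm(n, device=data.device)
    else:
        order = torch.arange(n, device=data.device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]  # last slice may be shorter than batch_size
        yield data[idx], targets[idx]


# Synthetic example: the value of each sample equals its label, so pairing is easy to check.
x = torch.arange(10).float().view(10, 1)
y = torch.arange(10)
batches = list(iterate_minibatches(x, y, batch_size=4))
```

For the datasets above this would be called once per epoch, for example as `iterate_minibatches(train_dataset.data, train_dataset.targets, 64)`.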

This results in a 2-3x speedup (500 it/s on a 1080 Ti with a smallish MLP) and near-zero CPU usage, compared to full CPU usage with the default dataset.
