The PyTorch MNIST dataset is SLOW by default because it conforms to the usual interface of returning a PIL image. This is unnecessary if you just want a normalized MNIST and are not interested in image transforms (such as rotation or cropping). By folding the normalization into the dataset initialization you can spare your CPU and speed up training by 2-3x.
The bottleneck when training on MNIST with a GPU and a small-ish model is the CPU. In fact, even with six dataloader workers on a six-core i7, GPU utilization is only ~5-10%. Using FastMNIST increases GPU utilization to ~20-25% and reduces CPU utilization to near zero. On my particular model, steps per second with batch size 64 went from ~150 to ~500.
Instead of the default MNIST dataset, use this:
import torch
from torchvision.datasets import MNIST
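A minimal sketch of such a FastMNIST subclass follows, assuming the standard MNIST normalization constants (mean 0.1307, std 0.3081); the exact method bodies here are illustrative, not a fixed API:

class FastMNIST(MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # self.data is a uint8 tensor of shape [N, 28, 28].
        # Add a channel dimension and scale to [0, 1].
        self.data = self.data.unsqueeze(1).float().div(255)

        # Normalize once, in place, with the usual MNIST mean and std.
        self.data = self.data.sub_(0.1307).div_(0.3081)

    def __getitem__(self, index):
        # Return the pre-normalized tensor directly, skipping the
        # PIL round-trip and per-item transform of the stock MNIST.
        return self.data[index], self.targets[index]

train_dataset = FastMNIST('data/MNIST', train=True, download=True)
test_dataset = FastMNIST('data/MNIST', train=False, download=True)

Note that this bypasses any transform passed to the dataset, which is exactly the trade-off described above. Since every item is already a normalized tensor, the DataLoader no longer needs worker processes (num_workers=0 works fine), and if the whole dataset fits in GPU memory you can even move self.data and self.targets to the device once in __init__ and skip per-batch host-to-device copies entirely.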