@harsh-99
Last active May 12, 2021 14:42
import os

from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Convert images to tensors in [0, 1], then normalize with the MNIST mean (0.1307) and std (0.3081)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
# The dataloaders handle batching; only the training set is shuffled
train_dataloader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_dataloader = DataLoader(mnist_val, batch_size=64)
test_dataloader = DataLoader(mnist_test, batch_size=64)
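In the snippet above, `transforms.ToTensor()` maps each 8-bit pixel to [0, 1] and `transforms.Normalize((0.1307,), (0.3081,))` then computes `(x - mean) / std` per channel; 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of the MNIST training pixels. The stdlib-only sketch below mimics those two steps for a single pixel so the arithmetic is visible. The helper names are hypothetical and this is not torchvision's implementation, which operates on whole tensors.

```python
# Pure-Python sketch of ToTensor + Normalize applied to one MNIST pixel.
# Assumes an 8-bit grayscale pixel value in [0, 255]; helper names are hypothetical.
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize in the snippet
MNIST_STD = 0.3081   # std passed to transforms.Normalize in the snippet


def to_unit_range(pixel: int) -> float:
    """Mimic ToTensor for one value: scale [0, 255] to [0.0, 1.0]."""
    if not 0 <= pixel <= 255:
        raise ValueError("expected an 8-bit pixel value")
    return pixel / 255.0


def normalize(value: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic Normalize for a single channel: (value - mean) / std."""
    return (value - mean) / std


def transform_pixel(pixel: int) -> float:
    """Apply both steps in the same order as transforms.Compose."""
    return normalize(to_unit_range(pixel))


if __name__ == "__main__":
    for p in (0, 128, 255):
        print(p, round(transform_pixel(p), 4))
```

A black pixel (0) ends up slightly below zero and a white pixel (255) near 2.8, so the normalized training inputs are roughly zero-mean with unit variance.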
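The MNIST training set has 60,000 images, so the lengths `[55000, 5000]` passed to `random_split` sum exactly to the dataset size, which integer lengths must do. Conceptually the split draws one random permutation of all indices and slices it into consecutive, disjoint pieces. The sketch below illustrates that idea with the standard library only; `split_indices` is a hypothetical helper, not the PyTorch API.

```python
import random
from typing import List, Sequence


def split_indices(n: int, lengths: Sequence[int], seed: int = 0) -> List[List[int]]:
    """Partition range(n) into disjoint random subsets of the given sizes.

    Hypothetical helper illustrating the idea behind
    torch.utils.data.random_split with integer lengths.
    """
    if sum(lengths) != n:
        raise ValueError(f"lengths sum to {sum(lengths)}, expected {n}")
    perm = list(range(n))
    random.Random(seed).shuffle(perm)  # one random permutation of all indices
    splits, offset = [], 0
    for length in lengths:
        splits.append(perm[offset:offset + length])  # consecutive slice of the permutation
        offset += length
    return splits


if __name__ == "__main__":
    train_idx, val_idx = split_indices(60_000, [55_000, 5_000], seed=42)
    print(len(train_idx), len(val_idx))
```

Because the permutation is random, the original snippet produces a different train/val split on every run. If reproducibility matters, `random_split` accepts a seeded generator, e.g. `generator=torch.Generator().manual_seed(42)`.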
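`DataLoader` only shuffles when given `shuffle=True` (reshuffling at each epoch); otherwise it iterates in dataset order. With the default `drop_last=False`, the final batch keeps whatever samples are left over. The stdlib-only sketch below approximates that behaviour to show how many batches each loader in the snippet yields at `batch_size=64`. `iterate_batches` and `num_batches` are hypothetical helpers, not part of PyTorch.

```python
import math
import random
from typing import Iterator, List, Sequence


def iterate_batches(indices: Sequence[int], batch_size: int = 64,
                    shuffle: bool = False, seed: int = 0) -> Iterator[List[int]]:
    """Roughly mimic a DataLoader over a map-style dataset: optionally
    shuffle the indices, then yield consecutive chunks of batch_size,
    keeping the final partial batch (as with drop_last=False)."""
    order = list(indices)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def num_batches(n: int, batch_size: int = 64) -> int:
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(n / batch_size)


if __name__ == "__main__":
    for name, n in (("train", 55_000), ("val", 5_000), ("test", 10_000)):
        batches = list(iterate_batches(range(n), 64, shuffle=(name == "train")))
        print(name, num_batches(n), len(batches), len(batches[-1]))
```

For the 55,000-image training split this gives 860 batches per epoch, the last one holding 24 images; shuffling is applied only to training so that validation and test metrics are computed over a fixed order.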