Skip to content

Instantly share code, notes, and snippets.

@tchaton
Last active December 6, 2020 19:04
Show Gist options
  • Save tchaton/c26885c2f1ecbb53c5bf124545cec7b9 to your computer and use it in GitHub Desktop.
Save tchaton/c26885c2f1ecbb53c5bf124545cec7b9 to your computer and use it in GitHub Desktop.
class MNISTDataModule(pl.LightningDataModule):
    """Data module that builds MNIST train/validation/test datasets and loaders.

    ``setup`` creates the three datasets; the ``*_dataloader`` methods wrap
    them in ``DataLoader`` objects that all share ``batch_size``.
    """

    # NOTE: ``batch_size`` needs a default because it follows the defaulted
    # ``data_dir`` parameter; without one the signature is a SyntaxError.
    def __init__(self, data_dir: str = PATH, batch_size: int = 32):
        """Record where the data lives and how large each batch should be.

        Args:
            data_dir: Root directory handed to ``MNIST``. Defaults to the
                module-level ``PATH``.
            batch_size: Batch size used by every dataloader of this module.
        """
        super().__init__()
        # ``setup`` reads ``self.data_dir``, so it must be stored here.
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Create the test dataset and split the training data into train/val.

        Args:
            stage: Accepted for compatibility with the caller; this
                implementation ignores it and always builds all three
                datasets.
        """
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        # Fixed 55000/5000 split of the training set; ``random_split``
        # requires these lengths to add up to ``len(mnist_full)``.
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # NOTE(review): no ``transform`` is passed to ``MNIST`` -- confirm the
        # samples it yields can be batched by the default ``DataLoader`` collate.

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training split made by ``setup``."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation split made by ``setup``."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test dataset made by ``setup``."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment