Created
September 24, 2021 10:23
-
-
Save raman-r-4978/1f134ffc52ce8d9e2da03c3c165b01e9 to your computer and use it in GitHub Desktop.
PL issue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import platform | |
from typing import Optional | |
from urllib.error import HTTPError | |
from warnings import warn | |
from pytorch_lightning import LightningDataModule | |
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE | |
from torch.utils.data import DataLoader, random_split | |
_DATASETS_PATH = os.environ.get("PATH_DATASETS", ".") | |
if _TORCHVISION_AVAILABLE: | |
from torchvision import transforms as transform_lib | |
_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) | |
if _TORCHVISION_MNIST_AVAILABLE: | |
try: | |
from torchvision.datasets import MNIST | |
MNIST(_DATASETS_PATH, download=True) | |
except HTTPError as e: | |
print(f"Error {e} downloading `torchvision.datasets.MNIST`") | |
_TORCHVISION_MNIST_AVAILABLE = False | |
if not _TORCHVISION_MNIST_AVAILABLE: | |
print("`torchvision.datasets.MNIST` not available. Using our hosted version") | |
from tests.helpers.datasets import MNIST | |
class MNISTDataModule(LightningDataModule): | |
"""Standard MNIST, train, val, test splits and transforms. | |
>>> MNISTDataModule() # doctest: +ELLIPSIS | |
<...mnist_datamodule.MNISTDataModule object at ...> | |
""" | |
name = "mnist" | |
def __init__( | |
self, | |
data_dir: str = _DATASETS_PATH, | |
val_split: int = 5000, | |
num_workers: int = 16, | |
normalize: bool = False, | |
seed: int = 42, | |
batch_size: int = 32, | |
*args, | |
**kwargs, | |
): | |
""" | |
Args: | |
data_dir: where to save/load the data | |
val_split: how many of the training images to use for the validation split | |
num_workers: how many workers to use for loading data | |
normalize: If true applies image normalize | |
seed: starting seed for RNG. | |
batch_size: desired batch size. | |
""" | |
super().__init__(*args, **kwargs) | |
if num_workers and platform.system() == "Windows": | |
# see: https://stackoverflow.com/a/59680818 | |
warn( | |
f"You have requested num_workers={num_workers} on Windows," | |
" but currently recommended is 0, so we set it for you" | |
) | |
num_workers = 0 | |
self.dims = (1, 28, 28) | |
self.data_dir = data_dir | |
self.val_split = val_split | |
self.num_workers = num_workers | |
self.normalize = normalize | |
self.seed = seed | |
self.batch_size = batch_size | |
self.dataset_train = ... | |
self.dataset_val = ... | |
self.test_transforms = self.default_transforms | |
@property | |
def num_classes(self): | |
return 10 | |
def prepare_data(self): | |
"""Saves MNIST files to `data_dir`""" | |
MNIST(self.data_dir, train=True, download=True) | |
MNIST(self.data_dir, train=False, download=True) | |
def setup(self, stage: Optional[str] = None): | |
"""Split the train and valid dataset.""" | |
extra = ( | |
dict(transform=self.default_transforms) if self.default_transforms else {} | |
) | |
dataset = MNIST(self.data_dir, train=True, download=False, **extra) | |
train_length = len(dataset) | |
self.dataset_train, self.dataset_val = random_split( | |
dataset, [train_length - self.val_split, self.val_split] | |
) | |
def train_dataloader(self): | |
"""MNIST train set removes a subset to use for validation.""" | |
loader = DataLoader( | |
self.dataset_train, | |
batch_size=self.batch_size, | |
shuffle=True, | |
num_workers=self.num_workers, | |
drop_last=True, | |
pin_memory=True, | |
) | |
return loader | |
def val_dataloader(self): | |
"""MNIST val set uses a subset of the training set for validation.""" | |
loader = DataLoader( | |
self.dataset_val, | |
batch_size=self.batch_size, | |
shuffle=False, | |
num_workers=self.num_workers, | |
drop_last=True, | |
pin_memory=True, | |
) | |
return [loader, loader] | |
def test_dataloader(self): | |
"""MNIST test set uses the test split.""" | |
extra = dict(transform=self.test_transforms) if self.test_transforms else {} | |
dataset = MNIST(self.data_dir, train=False, download=False, **extra) | |
loader = DataLoader( | |
dataset, | |
batch_size=self.batch_size, | |
shuffle=False, | |
num_workers=self.num_workers, | |
drop_last=True, | |
pin_memory=True, | |
) | |
return loader | |
@property | |
def default_transforms(self): | |
if not _TORCHVISION_AVAILABLE: | |
return None | |
if self.normalize: | |
mnist_transforms = transform_lib.Compose( | |
[ | |
transform_lib.ToTensor(), | |
transform_lib.Normalize(mean=(0.5,), std=(0.5,)), | |
] | |
) | |
else: | |
mnist_transforms = transform_lib.ToTensor() | |
return mnist_transforms |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment