Created
November 13, 2023 20:01
-
-
Save tchaton/da14112a3492db6c6e44795a288dd56f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright The Lightning AI team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""MNIST autoencoder example. | |
To run: python autoencoder.py --trainer.max_epochs=50 | |
""" | |
from os import path | |
from typing import Optional, Tuple | |
import torch | |
import torch.nn.functional as F | |
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo | |
from lightning.pytorch.cli import LightningCLI | |
from lightning.pytorch.demos.mnist_datamodule import MNIST | |
from lightning.pytorch.utilities import rank_zero_only | |
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE | |
from torch import nn | |
from torch.utils.data import DataLoader, random_split | |
if _TORCHVISION_AVAILABLE: | |
import torchvision | |
from torchvision import transforms | |
from torchvision.utils import save_image | |
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") | |
class ImageSampler(callbacks.Callback): | |
def __init__( | |
self, | |
num_samples: int = 3, | |
nrow: int = 8, | |
padding: int = 2, | |
normalize: bool = True, | |
value_range: Optional[Tuple[int, int]] = None, | |
scale_each: bool = False, | |
pad_value: int = 0, | |
) -> None: | |
""" | |
Args: | |
num_samples: Number of images displayed in the grid. Default: ``3``. | |
nrow: Number of images displayed in each row of the grid. | |
The final grid size is ``(B / nrow, nrow)``. Default: ``8``. | |
padding: Amount of padding. Default: ``2``. | |
normalize: If ``True``, shift the image to the range (0, 1), | |
by the min and max values specified by :attr:`range`. Default: ``False``. | |
value_range: Tuple (min, max) where min and max are numbers, | |
then these numbers are used to normalize the image. By default, min and max | |
are computed from the tensor. | |
scale_each: If ``True``, scale each image in the batch of | |
images separately rather than the (min, max) over all images. Default: ``False``. | |
pad_value: Value for the padded pixels. Default: ``0``. | |
""" | |
if not _TORCHVISION_AVAILABLE: # pragma: no cover | |
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") | |
super().__init__() | |
self.num_samples = num_samples | |
self.nrow = nrow | |
self.padding = padding | |
self.normalize = normalize | |
self.value_range = value_range | |
self.scale_each = scale_each | |
self.pad_value = pad_value | |
def _to_grid(self, images): | |
return torchvision.utils.make_grid( | |
tensor=images, | |
nrow=self.nrow, | |
padding=self.padding, | |
normalize=self.normalize, | |
value_range=self.value_range, | |
scale_each=self.scale_each, | |
pad_value=self.pad_value, | |
) | |
@rank_zero_only | |
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: | |
if not _TORCHVISION_AVAILABLE: | |
return | |
images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples))) | |
images_flattened = images.view(images.size(0), -1) | |
# generate images | |
with torch.no_grad(): | |
pl_module.eval() | |
images_generated = pl_module(images_flattened.to(pl_module.device)) | |
pl_module.train() | |
if trainer.current_epoch == 0: | |
save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png") | |
save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png") | |
class LitAutoEncoder(LightningModule): | |
""" | |
>>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE | |
LitAutoEncoder( | |
(encoder): ... | |
(decoder): ... | |
) | |
""" | |
def __init__(self, hidden_dim: int = 64, learning_rate=10e-3): | |
super().__init__() | |
self.save_hyperparameters() | |
self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3)) | |
self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)) | |
def forward(self, x): | |
z = self.encoder(x) | |
return self.decoder(z) | |
def training_step(self, batch, batch_idx): | |
return self._common_step(batch, batch_idx, "train") | |
def validation_step(self, batch, batch_idx): | |
self._common_step(batch, batch_idx, "val") | |
def test_step(self, batch, batch_idx): | |
self._common_step(batch, batch_idx, "test") | |
def predict_step(self, batch, batch_idx, dataloader_idx=None): | |
x = self._prepare_batch(batch) | |
return self(x) | |
def configure_optimizers(self): | |
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) | |
def _prepare_batch(self, batch): | |
x, _ = batch | |
return x.view(x.size(0), -1) | |
def _common_step(self, batch, batch_idx, stage: str): | |
x = self._prepare_batch(batch) | |
loss = F.mse_loss(x, self(x)) | |
self.log(f"{stage}_loss", loss, on_step=True) | |
return loss | |
class MyDataModule(LightningDataModule): | |
def __init__(self, batch_size: int = 32): | |
super().__init__() | |
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) | |
self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) | |
self.mnist_train, self.mnist_val = random_split( | |
dataset, [55000, 5000], generator=torch.Generator().manual_seed(42) | |
) | |
self.batch_size = batch_size | |
def train_dataloader(self): | |
return DataLoader(self.mnist_train, batch_size=self.batch_size) | |
def val_dataloader(self): | |
return DataLoader(self.mnist_val, batch_size=self.batch_size) | |
def test_dataloader(self): | |
return DataLoader(self.mnist_test, batch_size=self.batch_size) | |
def predict_dataloader(self): | |
return DataLoader(self.mnist_test, batch_size=self.batch_size) | |
def cli_main(): | |
cli = LightningCLI( | |
LitAutoEncoder, | |
MyDataModule, | |
seed_everything_default=1234, | |
run=False, # used to de-activate automatic fitting. | |
trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10}, | |
save_config_kwargs={"overwrite": True}, | |
) | |
cli.trainer.fit(cli.model, datamodule=cli.datamodule) | |
cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) | |
predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule) | |
print(predictions[0]) | |
if __name__ == "__main__": | |
cli_lightning_logo() | |
cli_main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment