CompressAI PyTorch Lightning example
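
This gist consists of three parts: a LightningDataModule that wraps CompressAI's ImageFolder dataset with the usual random-crop and flip training transforms, a LightningModule that wraps the SFUDenoiseScalable model and steps its main and auxiliary optimizers manually against a rate-distortion loss, and a training script that wires the two into a pl.Trainer with checkpointing, early stopping, and TensorBoard logging.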

import pytorch_lightning as pl
from compressai.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms


class CLICDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, patch_size, **dataloader_kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.train_transform = transforms.Compose(
            [
                transforms.RandomCrop(patch_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.CenterCrop(patch_size),
                transforms.ToTensor(),
            ]
        )
        self.test_transform = self.val_transform
        self.dataloader_kwargs = dataloader_kwargs

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        self.train_dataset = ImageFolder(
            self.data_dir, split="train", transform=self.train_transform
        )
        self.val_dataset = ImageFolder(
            self.data_dir, split="valid", transform=self.val_transform
        )
        self.test_dataset = ImageFolder(
            self.data_dir, split="test", transform=self.test_transform
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, shuffle=True, **self.dataloader_kwargs
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, shuffle=False, **self.dataloader_kwargs
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, shuffle=False, **self.dataloader_kwargs
        )
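
A minimal usage sketch of the data module above (not part of the gist). The path, patch size, and loader settings are placeholders, and the dataset directory is assumed to follow the train/, valid/, test/ layout that compressai.datasets.ImageFolder expects.

# Hypothetical usage; all values below are placeholders.
data_module = CLICDataModule(
    data_dir="/path/to/clic",  # must contain train/, valid/, test/ subdirectories
    patch_size=(256, 256),
    batch_size=16,
    num_workers=4,
    pin_memory=True,
)
data_module.setup()
batch = next(iter(data_module.train_dataloader()))  # image tensor of shape (16, 3, 256, 256)
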
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.optim as optim
from omegaconf import OmegaConf

from sfu_compression.losses import RateDistortionLoss
from sfu_compression.models import SFUDenoiseScalable
from sfu_compression.utils import (
    create_noise_model,
    git_branch_name,
    git_common_ancestor_hash,
    git_current_hash,
)


class LitSFUDenoiseScalable(pl.LightningModule):
    def __init__(
        self,
        conf: Optional[OmegaConf] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(conf)
        self.save_hyperparameters(kwargs)
        self.model = SFUDenoiseScalable(
            N=self.hparams.architecture.num_channels,
            BASE_N=self.hparams.architecture.num_base_channels,
        )
        self.criterion = RateDistortionLoss(
            lmbda=self.hparams.training.lmbda,
            w1d=self.hparams.training.w1d,
            w2d=self.hparams.training.w2d,
            w3d=self.hparams.training.w3d,
            w1r=self.hparams.training.w1r,
            w2r=self.hparams.training.w2r,
            w3r=self.hparams.training.w3r,
        )
        self.noise_model = create_noise_model(self.hparams.noise_model)
        # The main and auxiliary optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

    def forward(self, x):
        # TODO compress, decompress?
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        optimizer, aux_optimizer = self.optimizers()
        optimizer.zero_grad()
        aux_optimizer.zero_grad()
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        loss = out_criterion["loss"]
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.hparams.training.clip_max_norm
        )
        optimizer.step()
        aux_loss = self.model.aux_loss()
        self.manual_backward(aux_loss)
        aux_optimizer.step()
        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"train/{k}": v for k, v in log_dict.items()}
        self.log_dict(log_dict)

    def validation_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        aux_loss = self.model.aux_loss()
        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"val/{k}": v for k, v in log_dict.items()}
        # Duplicate the loss under a "/"-free key so ModelCheckpoint can use it
        # in its filename template.
        log_dict["val_loss"] = out_criterion["loss"]
        self.log_dict(log_dict)

    def validation_epoch_end(self, outputs):
        # With manual optimization, the LR scheduler must also be stepped by hand.
        sch = self.lr_schedulers()
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["val/loss"])
        else:
            raise RuntimeError(f"Unexpected scheduler type: {type(sch)}")

    def test_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        enc_dict = self.model.compress(x_noise)
        encoded = [s[0] for s in enc_dict["strings"]]
        result = self.model.decompress(**enc_dict)
        x_hat = result["x_hat"][0].cpu().numpy()
        # TODO log metrics, etc; on_epoch, on_step

    def configure_optimizers(self):
        # Calls the module-level configure_optimizers() helper defined below.
        optimizer, aux_optimizer = configure_optimizers(
            self.model, self.hparams.training
        )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "monitor": "val/loss",
                },
            },
            {
                "optimizer": aux_optimizer,
            },
        )

    def on_fit_start(self):
        params = {
            "git": {
                "branch_name": git_branch_name(),
                "hash": git_current_hash(),
                "master_hash": git_common_ancestor_hash(),
            },
            **self.hparams,
        }
        metrics = {"hp/metric": -1}
        self.logger.log_hyperparams(params, metrics)

    def on_load_checkpoint(self, checkpoint):
        # Checkpoints are saved with the "model." prefix stripped (see below),
        # so restore it before Lightning loads the state dict.
        prefix = "model."
        checkpoint["state_dict"] = {
            f"{prefix}{k}": v for k, v in checkpoint["state_dict"].items()
        }

    def on_save_checkpoint(self, checkpoint):
        # Strip the "model." prefix so the checkpoint matches the plain
        # SFUDenoiseScalable state dict layout.
        prefix_len = len("model.")
        checkpoint["state_dict"] = {
            k[prefix_len:]: v for k, v in checkpoint["state_dict"].items()
        }

def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.

    Return two optimizers.
    """
    parameters = {
        n
        for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n
        for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }

    # Make sure we don't have an intersection of parameters.
    params_dict = dict(net.named_parameters())
    inter_params = parameters & aux_parameters
    union_params = parameters | aux_parameters
    assert len(inter_params) == 0
    assert len(union_params) - len(params_dict.keys()) == 0

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
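
Purely as an illustration of the parameter split (not part of the gist): the ".quantiles" parameters belong to CompressAI's entropy bottleneck and are trained by the auxiliary optimizer, while everything else goes to the main optimizer. The sketch below exercises configure_optimizers() with a stock CompressAI model standing in for SFUDenoiseScalable, whose source is not included here.

# Illustrative only: a stock CompressAI model stands in for the gist's SFUDenoiseScalable.
from types import SimpleNamespace

from compressai.zoo import bmshj2018_factorized

net = bmshj2018_factorized(quality=1)
args = SimpleNamespace(learning_rate=1e-4, aux_learning_rate=1e-3)
optimizer, aux_optimizer = configure_optimizers(net, args)
# Only the entropy-bottleneck quantiles end up in the auxiliary optimizer:
print([n for n, _ in net.named_parameters() if n.endswith(".quantiles")])
# -> ['entropy_bottleneck.quantiles']
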
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torchinfo import summary

from sfu_compression.datasets import CLICDataModule
from sfu_compression.models import LitSFUDenoiseScalable
from sfu_compression.utils import parse_args_training


def load_model_from_args(args: OmegaConf):
    continue_from = args.training_params.continue_from
    if continue_from != "":
        checkpoint_path = f"{continue_from}/checkpoints/last.ckpt"
        return LitSFUDenoiseScalable.load_from_checkpoint(checkpoint_path)
    conf = OmegaConf.merge(args.hparams, {"noise_model": args.noise_model})
    return LitSFUDenoiseScalable(conf=conf)


def main():
    args = parse_args_training()

    if args.training_params.seed is not None:
        pl.seed_everything(args.training_params.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    data_module = CLICDataModule(
        data_dir=args.other.dataset,
        patch_size=args.training_params.patch_size,
        batch_size=args.training_params.batch_size,
        num_workers=args.training_params.num_workers,
        pin_memory=True,
    )

    model = load_model_from_args(args)

    # Show network with layer sizes.
    h, w = args.training_params.patch_size
    empty_img = (1, 3, h, w)
    summary(model.model, [empty_img])

    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        filename="{epoch:04d}-{val_loss:.2f}",
        save_last=True,
        save_top_k=1,
        mode="min",
    )
    early_stopping_callback = EarlyStopping("val/loss", patience=15)
    lr_monitor_callback = LearningRateMonitor(logging_interval="epoch")

    tb_logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="",
        default_hp_metric=False,
    )

    trainer_kwargs = dict(
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor_callback,
        ],
        logger=tb_logger,
    )
    trainer_kwargs = {**args.pytorch_lightning_trainer, **trainer_kwargs}

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    main()
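
parse_args_training is not included in the gist. Purely as an illustration of the configuration keys that the code above reads, an OmegaConf tree shaped roughly like the following would satisfy it; every value is a placeholder, and the real argument format may differ.

from omegaconf import OmegaConf

# Hypothetical configuration; only the key names are taken from the code above.
args = OmegaConf.create(
    {
        "training_params": {
            "continue_from": "",
            "seed": 42,
            "patch_size": [256, 256],
            "batch_size": 16,
            "num_workers": 4,
        },
        "other": {"dataset": "/path/to/clic"},
        "noise_model": {},  # passed to create_noise_model()
        "pytorch_lightning_trainer": {"max_epochs": 1000},
        "hparams": {
            "architecture": {"num_channels": 192, "num_base_channels": 128},
            "training": {
                "lmbda": 0.01,
                "w1d": 1.0, "w2d": 1.0, "w3d": 1.0,
                "w1r": 1.0, "w2r": 1.0, "w3r": 1.0,
                "clip_max_norm": 1.0,
                "learning_rate": 1e-4,
                "aux_learning_rate": 1e-3,
            },
        },
    }
)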