Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Demonstrates the issue of self.hparams not being restored when loading a model from a checkpoint. Details can be found here:
from abc import abstractmethod

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy, to_categorical
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
class ModelBase(pl.LightningModule):
    """Base LightningModule that records hyperparameters in ``self.hparams``.

    When ``pretrained_hparams`` is true, the constructor arguments are saved
    to ``self.hparams`` and the layer-size entries (``hlayer1``,
    ``classes_num``) are derived from ``sample_rate``. When it is false,
    nothing is written to ``self.hparams`` here, so subclasses that index
    ``self.hparams`` depend on it having been populated some other way.
    """

    def __init__(self, pretrained_hparams: bool, **kwargs):  # **kwargs # sample_rate
        """Initialise the module and optionally store its hyperparameters.

        Args:
            pretrained_hparams: if true, save the init arguments to
                ``self.hparams`` and add the derived layer sizes.
            **kwargs: remaining hyperparameters (e.g. ``sample_rate``); they
                are not read directly here but are picked up by
                ``save_hyperparameters()``.
        """
        print(f"Init ModelBase, hparams:\n{self.hparams}\n")
        # nn.Module must be initialised before any submodule can be assigned
        # by the subclasses.
        super().__init__()
        print(f"Init ModelBase after, hparams:\n{self.hparams}\n")
        # use PANNs.load_from_checkpoint when loading weights after transfer learning
        if pretrained_hparams:
            # save all arguments in self.hparams
            self.save_hyperparameters()
            print("Argument hparams: ", self.hparams)
            # needed hparams for non-lightning pre-trained weights
            self.set_pretrained_hparams()
        # print("All hparams: ", self.hparams)

    @abstractmethod
    def forward(self, x):
        """Compute the model output for ``x``; must be provided by subclasses."""
        raise NotImplementedError

    def set_pretrained_hparams(self):
        """Add the layer-size entries that depend on ``sample_rate``.

        ``hlayer1`` is set to 400 for a sample rate of 8000 and to 800 for
        16000; for any other sample rate it is left unset. ``classes_num`` is
        always set to 3.
        """
        if self.hparams["sample_rate"] == 8000:
            self.hparams["hlayer1"] = 400
        elif self.hparams["sample_rate"] == 16000:
            self.hparams["hlayer1"] = 800
        self.hparams["classes_num"] = 3

    def load_non_lightning_weights(self, weights_path):
        """Placeholder for loading weights that were not saved by Lightning.

        The intended implementation is kept below as a comment; the method
        currently does nothing and returns ``None``.
        """
        # checkpoint = torch.load(weights_path)
        # self.load_state_dict(checkpoint['model'])
        # 1 variant
class Linear3(ModelBase):
    """Three-layer fully connected network sized from ``self.hparams``.

    Layer widths are read from the ``sample_rate``, ``hlayer1`` and
    ``classes_num`` entries of ``self.hparams`` after the base class has been
    initialised.
    """

    def __init__(self, sample_rate, **kwargs):
        """Forward all arguments to the base class, then build the layers."""
        print(f"Init Linear3, hparams:\n{self.hparams}\n")
        super().__init__(sample_rate=sample_rate, **kwargs)
        print(f"Init Linear3 after, hparams:\n{self.hparams}\n")

        hp = self.hparams
        # the input width equals the sample rate, i.e. 1 sec of audio
        self.input_layer = nn.Linear(hp["sample_rate"], hp["hlayer1"], bias=True)
        self.hidden_layer = nn.Linear(hp["hlayer1"], 128, bias=True)
        self.output_layer = nn.Linear(128, hp["classes_num"], bias=True)

    def forward(self, input):
        """Apply input and hidden layers with in-place ReLU, then the output layer.

        The result of the output layer is returned as-is (no sigmoid).
        """
        first = F.relu_(self.input_layer(input))
        second = F.relu_(self.hidden_layer(first))
        return self.output_layer(second)
class ModelTrainer(pl.LightningModule):
    """Training logic: loss, train/val/test steps and the optimizer.

    ``__init__`` forwards ``learning_rate`` and all other keyword arguments
    to the next class in the MRO, so this class is written to be combined
    with a model class that accepts them.
    """

    # arguments should NOT be positional due to inheritance; always have a default value
    def __init__(self, learning_rate=1e-3, **kwargs):
        """Pass every argument up the MRO and create the loss criterion."""
        print(f"Init ModelTrainer, hparams:\n{self.hparams}\n")
        # Whatever is forwarded in this call is what can end up in
        # self.hparams, and therefore in a saved .ckpt file.
        super().__init__(learning_rate=learning_rate, **kwargs)
        print(f"Init ModelTrainer after, hparams:\n{self.hparams}\n")
        self.criterion = nn.CrossEntropyLoss()

    def calculate_loss(self, prediction, target):
        """Return ``self.criterion(prediction, target)`` (``nn.CrossEntropyLoss``)."""
        return self.criterion(prediction, target)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log it as ``train_loss``."""
        features, labels = batch
        loss = self.calculate_loss(self(features), labels)
        outcome = pl.TrainResult(minimize=loss)
        outcome.log('train_loss', loss)
        return outcome

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc``; the loss is passed as ``checkpoint_on``."""
        features, labels = batch
        logits = self(features)
        loss = self.calculate_loss(logits, labels)
        outcome = pl.EvalResult(checkpoint_on=loss)
        outcome.log('val_loss', loss)
        outcome.log('val_acc', accuracy(logits, labels))
        return outcome

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one batch."""
        features, labels = batch
        logits = self(features)
        loss = self.calculate_loss(logits, labels)
        outcome = pl.EvalResult()
        outcome.log('test_loss', loss)
        outcome.log('test_acc', accuracy(logits, labels))
        return outcome

    def configure_optimizers(self):
        """Return an Adam optimizer using ``self.hparams["learning_rate"]``."""
        step_size = self.hparams["learning_rate"]
        return torch.optim.Adam(self.parameters(), lr=step_size)
class MyModel(ModelTrainer, Linear3):
    """Concrete model combining the ``ModelTrainer`` logic with ``Linear3``."""

    def __init__(self, unfreeze_epoch=1, **kwargs):
        """Forward ``unfreeze_epoch`` and all other keyword arguments up the MRO.

        Arguments passed along here are the ones that can be stored in
        ``self.hparams``. Freezing the model and keeping ``unfreeze_epoch`` as
        an attribute are currently disabled.
        """
        print(f"Init MyModel, hparams:\n{self.hparams}\n")
        super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)
        print(f"Init MyModel after, hparams:\n{self.hparams}\n")

    def forward(self, input, mixup_lambda=None):
        """Run the three linear layers with in-place ReLU between them.

        ``mixup_lambda`` is accepted but not used. Unfreezing the layers once
        ``unfreeze_epoch`` is reached is currently disabled.
        """
        activated = F.relu_(self.input_layer(input))
        activated = F.relu_(self.hidden_layer(activated))
        return self.output_layer(activated)
class SimpleDataset(Dataset):
    """Synthetic dataset whose samples are constant vectors encoding the label.

    Every item is a ``(signal, target)`` pair: ``target`` is a random class
    index in {0, 1, 2} and ``signal`` is a 1-D tensor of length
    ``sample_rate`` filled with ``target / 2`` (0.0, 0.5 or 1.0).

    Args:
        sample_rate: length of each generated signal.
        num_samples: value reported by ``len()``; defaults to 16.
    """

    def __init__(self, sample_rate=8000, num_samples=16):
        self.sample_rate = sample_rate
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a freshly generated ``(signal, target)`` pair.

        ``idx`` is ignored: a new random label is drawn on every access.
        """
        # 0, 1 or 2, squeezed to a 0-dim tensor
        target = torch.randint(0, 3, size=(1,)).squeeze()
        fill_value = (target.float() / 2).item()
        signal = torch.full((self.sample_rate,), fill_value)
        return signal, target
class SimpleDatamodule(pl.LightningDataModule):
    """Data module serving ``SimpleDataset`` for the train, val and test splits.

    Each dataloader method builds a new default ``SimpleDataset`` and wraps
    it in a ``DataLoader`` with a batch size of 4.
    """

    def setup(self, stage: str = None):
        """Nothing to prepare: the datasets are created in the dataloader methods."""

    def train_dataloader(self):
        """Return the training ``DataLoader``."""
        return DataLoader(SimpleDataset(), batch_size=4)

    def val_dataloader(self):
        """Return the validation ``DataLoader``."""
        return DataLoader(SimpleDataset(), batch_size=4)

    def test_dataloader(self):
        """Return the test ``DataLoader``."""
        return DataLoader(SimpleDataset(), batch_size=4)
if __name__ == '__main__':
    # Script flow: build the model, train it briefly, save a checkpoint,
    # inspect the saved hyperparameters, then try to restore a second model.
    sr = 8000
    checkpoint_location = "example.ckpt"

    # network
    model = MyModel(sample_rate=sr, pretrained_hparams=True)
    print(f"After all init, hparams:\n{model.hparams}\n")

    # data
    dm = SimpleDatamodule()

    # train
    trainer = Trainer(max_epochs=4, deterministic=True)  # gpus=1
    trainer.fit(model, datamodule=dm)

    # save
    trainer.save_checkpoint(checkpoint_location)

    # check model contents
    print("\n\nModel save completed. Checking contents saved model...")
    checkpoint = torch.load(checkpoint_location)  # , map_location='cuda:0'
    print(f"Checkpoint hyper parameters:\n{checkpoint['hyper_parameters']}")  # .keys() # ['state_dict']

    # ERROR: load weights into new model
    # This is the step reported as failing with KeyError: 'sample_rate'.
    print("\nContents check completed. Trying to restore model with checkpoint...")
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment