@NumesSanguis
Last active October 8, 2020 03:04
Demonstrates an issue where self.hparams is not restored when loading from a checkpoint. Details can be found here: https://forums.pytorchlightning.ai/t/hparams-not-restored-when-using-load-from-checkpoint-default-argument-values-are-the-problem/237
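# Expected behaviour (sketch of the intent, not guaranteed output): the arguments
# captured by self.save_hyperparameters() end up in checkpoint['hyper_parameters']
# and should be handed back to __init__ by MyModel.load_from_checkpoint(), e.g.
#
#   model2 = MyModel.load_from_checkpoint("example.ckpt")
#   model2.hparams["sample_rate"]  # expected: 8000
#
# Observed behaviour: because the __init__ arguments have default values (or arrive
# through **kwargs), the saved hparams are not restored and construction fails with
# KeyError: 'sample_rate' (see the bottom of this file).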
from abc import abstractmethod
import torch
from torch import nn as nn
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, to_categorical
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
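# NOTE: this script targets the PyTorch Lightning 0.9/1.0-era API
# (pl.TrainResult, pl.EvalResult, pytorch_lightning.metrics.functional);
# later releases removed these in favour of self.log() and torchmetrics.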
class ModelBase(pl.LightningModule):
    def __init__(self, pretrained_hparams: bool, **kwargs):  # **kwargs # sample_rate
        print(f"Init ModelBase, hparams:\n{self.hparams}\n")
        super().__init__()
        print(f"Init ModelBase after, hparams:\n{self.hparams}\n")

        # use PANNs.load_from_checkpoint when loading weights after transfer learning
        if pretrained_hparams:
            # save all arguments in self.hparams
            self.save_hyperparameters()
            print("Argument hparams: ", self.hparams)
            # needed hparams for non-lightning pre-trained weights
            self.set_pretrained_hparams()
        # print("All hparams: ", self.hparams)

    @abstractmethod
    def forward(self, x):
        pass

    def set_pretrained_hparams(self):
        if self.hparams["sample_rate"] == 8000:
            self.hparams["hlayer1"] = 400
        elif self.hparams["sample_rate"] == 16000:
            self.hparams["hlayer1"] = 800
        self.hparams["classes_num"] = 3

    def load_non_lightning_weights(self, weights_path):
        # checkpoint = torch.load(weights_path)
        # self.load_state_dict(checkpoint['model'])
        pass
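# pretrained_hparams=True is used for the first training run: save_hyperparameters()
# stores the constructor arguments and set_pretrained_hparams() derives hlayer1 and
# classes_num from sample_rate. pretrained_hparams=False is used when restoring via
# load_from_checkpoint(), where these values are expected to come back from the
# checkpoint instead (the step this gist shows failing).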
# 1 variant
class Linear3(ModelBase):
    def __init__(self, sample_rate, **kwargs):
        print(f"Init Linear3, hparams:\n{self.hparams}\n")
        super().__init__(sample_rate=sample_rate, **kwargs)
        print(f"Init Linear3 after, hparams:\n{self.hparams}\n")

        # 1 sec of audio
        self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
        self.hidden_layer = nn.Linear(self.hparams["hlayer1"], 128, bias=True)
        self.output_layer = nn.Linear(128, self.hparams["classes_num"], bias=True)

    def forward(self, input):
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output
class ModelTrainer(pl.LightningModule):
    # arguments should NOT be positional due to inheritance; always have a default value
    def __init__(self, learning_rate=1e-3, **kwargs):  # **kwargs
        print(f"Init ModelTrainer, hparams:\n{self.hparams}\n")
        # everything included in the init call will be included in self.hparams (here only kwargs is included);
        # meaning only those will be saved in a .ckpt file
        super().__init__(learning_rate=learning_rate, **kwargs)  # **kwargs
        print(f"Init ModelTrainer after, hparams:\n{self.hparams}\n")

        self.criterion = nn.CrossEntropyLoss()

    def calculate_loss(self, prediction, target):
        """Cross-entropy loss"""
        # loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = self.criterion(prediction, target)
        return loss

    def training_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)

        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)

        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        result.log('val_acc', accuracy(prediction, target))
        return result

    def test_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)

        result = pl.EvalResult()  # checkpoint_on=loss
        result.log('test_loss', loss)
        result.log('test_acc', accuracy(prediction, target))  # to_categorical()
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["learning_rate"])
class MyModel(ModelTrainer, Linear3):
    def __init__(self, unfreeze_epoch=1, **kwargs):
        # arguments passed here are stored in self.hparams
        print(f"Init MyModel, hparams:\n{self.hparams}\n")
        super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)  # unfreeze_epoch=unfreeze_epoch, **kwargs
        print(f"Init MyModel after, hparams:\n{self.hparams}\n")
        # print("hparams after init: ", self.hparams)

        # self.unfreeze_epoch = unfreeze_epoch
        # self.freeze()

    def forward(self, input, mixup_lambda=None):
        # unfreeze deep layers after unfreeze_epoch epochs
        # if self.current_epoch == self.unfreeze_epoch:
        #     self.unfreeze()

        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output
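# MRO for MyModel: ModelTrainer -> Linear3 -> ModelBase -> pl.LightningModule.
# The cooperative super().__init__() calls pass **kwargs down this chain, so by the
# time save_hyperparameters() runs in ModelBase, all keyword arguments
# (learning_rate, unfreeze_epoch, sample_rate, pretrained_hparams) are available
# to be stored in self.hparams.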
# DATA
class SimpleDataset(Dataset):
    def __init__(self, sample_rate=8000):
        self.sample_rate = sample_rate

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        # 0, 1 or 2
        target = torch.randint(0, 3, size=(1,)).squeeze()
        # size 8000/16000 of 0.0, 0.5, or 1.0
        input = torch.full((self.sample_rate,), (target.float() / 2).item())
        # torch.empty(self.sample_rate,).fill_(target.float()/2)
        return input, target
class SimpleDatamodule(pl.LightningDataModule):
    def setup(self, stage: str = None):
        pass

    def train_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)
        # dataset = self._set_dataset_split("val")
        # return DataLoader(dataset, batch_size=self.hparams["batch_size"],
        #                   sampler=SubsetRandomSampler(dataset.indices), num_workers=4)

    def test_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)
if __name__ == '__main__':
    sr = 8000
    checkpoint_location = "example.ckpt"

    # network
    model = MyModel(sample_rate=sr, pretrained_hparams=True)
    print(f"After all init, hparams:\n{model.hparams}\n")

    # data
    dm = SimpleDatamodule()

    # train
    trainer = Trainer(max_epochs=4, deterministic=True)  # gpus=1,
    trainer.fit(model, dm)

    # save
    trainer.save_checkpoint(checkpoint_location)

    # check model contents
    print("\n\nModel save completed. Checking contents of saved model...")
    checkpoint = torch.load(checkpoint_location)  # , map_location='cuda:0'
    print(f"Checkpoint hyper parameters:\n{checkpoint['hyper_parameters']}")  # .keys() # ['state_dict']

    # ERROR: load weights into new model
    print("\nContents check completed. Trying to restore model with checkpoint...")
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
    # KeyError: 'sample_rate'
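    # Possible workaround (untested sketch): pass the values that __init__ needs
    # explicitly, so construction no longer depends on them being restored from
    # the checkpoint:
    # model2 = MyModel.load_from_checkpoint(checkpoint_location,
    #                                       pretrained_hparams=True, sample_rate=sr)
    # With pretrained_hparams=True, hlayer1 and classes_num are recomputed in
    # set_pretrained_hparams(), so the layer shapes match the saved state_dict.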