Created
October 12, 2020 07:47
-
-
Save NumesSanguis/388b4cfab2a8945afa85e8b79cd0c794 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations # don't crash on non-imported libraries when type checking | |
from typing import Dict, Optional, Union | |
import torch | |
from torch import nn as nn | |
import pytorch_lightning as pl | |
from pytorch_lightning import Trainer | |
from torch.utils.data import Dataset, DataLoader | |
from torch.nn import functional as F | |
class MyModel(pl.LightningModule):
    """Small fully connected classifier used to demonstrate TorchScript export.

    An 8000-value input (one second of audio, per the layer comment below) is
    passed through two ReLU-activated linear layers and projected to 3 logits.
    Training, validation and test steps all return the cross-entropy loss.
    """

    def __init__(self):  # **kwargs  # sample_rate
        super().__init__()
        # self.save_hyperparameters()

        # 1 sec of audio
        self.input_layer = nn.Linear(8000, 400, bias=True)
        self.hidden_layer = nn.Linear(400, 128, bias=True)
        self.output_layer = nn.Linear(128, 3, bias=True)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input):
        """Return the raw (unnormalised) class logits for ``input``."""
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        # No activation here: CrossEntropyLoss is applied to the raw logits.
        output = self.output_layer(x)
        return output

    def calculate_loss(self, prediction, target):
        """Return the cross-entropy loss between ``prediction`` and ``target``."""
        return self.criterion(prediction, target)

    def _shared_step(self, batch):
        """Run the forward pass on an ``(input, target)`` batch and return its loss."""
        inputs, target = batch
        prediction = self(inputs)
        return self.calculate_loss(prediction, target)

    def training_step(self, batch, batch_idx):
        """Return the training loss; also caches an example input on first use."""
        inputs, _ = batch

        # Use the first batch to create an example input for tracing.
        if self.example_input_array is None:
            # Only one sample is needed, not the whole batch. The slice keeps
            # the batch dimension; clone() avoids holding on to the batch storage.
            self.example_input_array = inputs[:1].detach().clone()

        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``."""
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx):
        """Return the test loss for ``batch``."""
        return self._shared_step(batch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # TODO create pull request for this
    def to_torchscript(
        self,
        file_path: Optional[str] = None,
        torchscript_approach: Optional[str] = 'script',
        example_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.jit.ScriptModule, Dict[str, torch.jit.ScriptModule]]:
        """Compile the model to TorchScript by scripting or tracing.

        The model is put in eval mode for the export and its previous
        train/eval mode is restored afterwards, even if the export fails.

        Args:
            file_path: If given, the compiled module is saved to this path
                with ``torch.jit.save``.
            torchscript_approach: ``'script'`` to use ``torch.jit.script`` or
                ``'trace'`` to use ``torch.jit.trace``.
            example_inputs: Tensor fed through the model when tracing. When
                omitted, ``self.example_input_array`` is used.
            **kwargs: Forwarded to ``torch.jit.script`` / ``torch.jit.trace``.

        Returns:
            The scripted or traced module.

        Raises:
            ValueError: If ``torchscript_approach`` is not ``'script'`` or
                ``'trace'``, or if tracing is requested and no example input
                is available.
        """
        if torchscript_approach not in ('script', 'trace'):
            raise ValueError(
                "torchscript_approach only supports 'script' or 'trace', but value given was: "
                f"{torchscript_approach}"
            )

        # Remember whether we were training so the mode can be restored.
        mode = self.training
        try:
            with torch.no_grad():
                if torchscript_approach == 'script':
                    scripted_module = torch.jit.script(self.eval(), **kwargs)
                else:
                    if example_inputs is None:
                        example_inputs = self.example_input_array
                    if example_inputs is None:
                        raise ValueError(
                            "Tracing requires `example_inputs` or a model with "
                            "`example_input_array` set."
                        )

                    # The example must live on the same device as the weights,
                    # otherwise the traced forward pass fails with a device
                    # mismatch (e.g. example cached on GPU, model back on CPU).
                    model_device = next(self.parameters()).device
                    example_inputs = example_inputs.to(model_device)
                    print(f"\n\nExample inputs device: {example_inputs.device}\n")

                    scripted_module = torch.jit.trace(
                        func=self.eval(), example_inputs=example_inputs, **kwargs
                    )
        finally:
            # Set back whether we were training or not, also on failure.
            self.train(mode)

        if file_path is not None:
            torch.jit.save(scripted_module, file_path)

        return scripted_module


# DATA | |
class SimpleDataset(Dataset):
    """Synthetic dataset of 16 constant-valued signals with random labels.

    Each item is ``(input, target)``: ``target`` is a 0-dim integer tensor
    drawn uniformly from {0, 1, 2} on every access, and ``input`` is a 1-D
    tensor of length ``sample_rate`` filled with ``target / 2`` (0.0, 0.5 or 1.0).
    """

    def __init__(self, sample_rate=8000):
        self.sample_rate = sample_rate

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        """Return a freshly generated ``(input, target)`` pair.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this, plain iteration over the dataset never stops.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range for dataset of size {size}")

        # 0, 1 or 2
        target = torch.randint(0, 3, size=(1,)).squeeze()
        # Constant signal of 0.0, 0.5 or 1.0, so the label is recoverable from the input.
        fill_value = (target.float() / 2).item()
        signal = torch.full((self.sample_rate,), fill_value)
        return signal, target


class SimpleDatamodule(pl.LightningDataModule):
    """Data module that serves a new ``SimpleDataset`` for every stage."""

    def setup(self, stage: str = None):
        # Nothing to prepare; each loader builds its own dataset.
        return None

    def _build_loader(self):
        # A new dataset instance per loader, batched in groups of 4.
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=4)

    def train_dataloader(self):
        return self._build_loader()

    def val_dataloader(self):
        return self._build_loader()

    def test_dataloader(self):
        return self._build_loader()


if __name__ == '__main__':
    checkpoint_location = "example.ckpt"

    # network
    model = MyModel()

    # data
    dm = SimpleDatamodule()

    # train on one GPU when CUDA is available, otherwise fall back to CPU
    # so the script also runs on machines without a GPU
    trainer = Trainer(
        max_epochs=2,
        deterministic=True,
        gpus=1 if torch.cuda.is_available() else None,
    )
    trainer.fit(model, dm)

    # save
    trainer.save_checkpoint(checkpoint_location)

    # save model as TorchScript using eval(); tracing uses the example input
    # that training_step cached on the model during the first batch
    model.to_torchscript(file_path="example.pt", torchscript_approach='trace')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment