Skip to content

Instantly share code, notes, and snippets.

@NumesSanguis
Created October 12, 2020 07:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NumesSanguis/388b4cfab2a8945afa85e8b79cd0c794 to your computer and use it in GitHub Desktop.
Save NumesSanguis/388b4cfab2a8945afa85e8b79cd0c794 to your computer and use it in GitHub Desktop.
from __future__ import annotations # don't crash on non-imported libraries when type checking
from typing import Dict, Optional, Union
import torch
from torch import nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
class MyModel(pl.LightningModule):
    """Three-layer fully connected classifier that outputs 3 logits per sample.

    Layer sizes are 8000 -> 400 -> 128 -> 3. The first training batch is used
    to populate ``example_input_array`` so the model can later be exported with
    ``to_torchscript(torchscript_approach='trace')`` without passing inputs.
    """

    def __init__(self):
        super().__init__()
        # 8000 input features = 1 sec of audio at SimpleDataset's default sample_rate.
        # TODO: take sample_rate as an argument instead of hard-coding 8000.
        self.input_layer = nn.Linear(8000, 400, bias=True)
        self.hidden_layer = nn.Linear(400, 128, bias=True)
        self.output_layer = nn.Linear(128, 3, bias=True)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input):
        """Map ``input`` of shape (..., 8000) to raw logits of shape (..., 3).

        No activation is applied to the output layer: ``nn.CrossEntropyLoss``
        expects unnormalised logits.
        """
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        return self.output_layer(x)

    def calculate_loss(self, prediction, target):
        """Return the cross-entropy loss of ``prediction`` logits against ``target`` labels."""
        return self.criterion(prediction, target)

    def _shared_step(self, batch):
        """Unpack an ``(inputs, target)`` batch and return the model's loss on it."""
        inputs, target = batch
        return self.calculate_loss(self(inputs), target)

    def training_step(self, batch, batch_idx):
        """Return the training loss; on the first call, record a tracing example."""
        inputs, _ = batch
        if self.example_input_array is None:
            # Keep one detached sample (not the whole batch) as the tracing example.
            # Clone it so the stored tensor does not pin the full batch's storage.
            self.example_input_array = inputs[:1].detach().clone()
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``."""
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx):
        """Return the test loss for ``batch``."""
        return self._shared_step(batch)

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # TODO create pull request for this
    def to_torchscript(
            self,
            file_path: Optional[str] = None,
            torchscript_approach: Optional[str] = 'script',
            example_inputs: Optional[torch.Tensor] = None,
            **kwargs,
    ) -> Union[torch.jit.ScriptModule, Dict[str, torch.jit.ScriptModule]]:
        """Compile the model to TorchScript and optionally save it to ``file_path``.

        Args:
            file_path: if given, the compiled module is written there with ``torch.jit.save``.
            torchscript_approach: ``'script'`` (``torch.jit.script``) or
                ``'trace'`` (``torch.jit.trace``).
            example_inputs: inputs used for tracing. They default to
                ``self.example_input_array``. A tensor is moved to the model's
                device before tracing.
            **kwargs: forwarded to ``torch.jit.script`` / ``torch.jit.trace``.

        Returns:
            The compiled module. Compilation runs in eval mode, and the model's
            previous train/eval mode is restored afterwards, even on failure.

        Raises:
            ValueError: if ``torchscript_approach`` is unknown, or if tracing is
                requested and no example inputs are available.
        """
        # Remember training vs. eval mode so it can be restored in `finally`.
        mode = self.training
        try:
            with torch.no_grad():
                if torchscript_approach == 'script':
                    scripted_module = torch.jit.script(self.eval(), **kwargs)
                elif torchscript_approach == 'trace':
                    if example_inputs is None:
                        example_inputs = self.example_input_array
                    if example_inputs is None:
                        raise ValueError(
                            "torchscript_approach='trace' requires `example_inputs` "
                            "or `example_input_array` to be set"
                        )
                    if isinstance(example_inputs, torch.Tensor):
                        # Tracing fails if the inputs and the weights are on different
                        # devices (e.g. an example captured on GPU, model now on CPU).
                        example_inputs = example_inputs.to(self.device)
                    scripted_module = torch.jit.trace(
                        func=self.eval(), example_inputs=example_inputs, **kwargs
                    )
                else:
                    raise ValueError(
                        f"torchscript_approach only supports 'script' or 'trace', but value given was: "
                        f"{torchscript_approach}"
                    )
        finally:
            self.train(mode)

        if file_path is not None:
            torch.jit.save(scripted_module, file_path)
        return scripted_module
# DATA
class SimpleDataset(Dataset):
    """Synthetic dataset of constant-valued signals labelled 0, 1 or 2.

    Each sample is ``(input, target)``:
    - ``target`` is a 0-dim integer tensor drawn uniformly from {0, 1, 2} with
      torch's global RNG on every access, so re-reading an index can give a
      different sample.
    - ``input`` is a float tensor of shape ``(sample_rate,)`` filled with
      ``target / 2``, i.e. 0.0, 0.5 or 1.0.

    Args:
        sample_rate: length of each input vector (8000 or 16000 in the original use).
        length: number of samples reported by ``len()``; defaults to 16.
    """

    def __init__(self, sample_rate=8000, length=16):
        self.sample_rate = sample_rate
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Dataset defines no __iter__, so `for x in ds` / `list(ds)` index
        # upward until IndexError. Without this bounds check they never stop.
        if not -self.length <= idx < self.length:
            raise IndexError(f"index {idx} out of range for dataset of length {self.length}")
        # 0, 1 or 2
        target = torch.randint(0, 3, size=(1,)).squeeze()
        # Constant signal of 0.0, 0.5 or 1.0.
        inputs = torch.full((self.sample_rate,), (target.float() / 2).item())
        return inputs, target
class SimpleDatamodule(pl.LightningDataModule):
    """LightningDataModule serving a fresh SimpleDataset for train, val and test.

    Args:
        batch_size: batch size used by all three dataloaders (default 4).
    """

    def __init__(self, batch_size: int = 4):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        # Nothing to prepare: each *_dataloader call builds its own synthetic dataset.
        pass

    def _make_loader(self):
        """Build a DataLoader over a new SimpleDataset with the configured batch size."""
        return DataLoader(SimpleDataset(), batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_loader()

    def val_dataloader(self):
        return self._make_loader()

    def test_dataloader(self):
        return self._make_loader()
if __name__ == '__main__':
    checkpoint_location = "example.ckpt"

    # `deterministic=True` makes cuDNN deterministic but does not seed any RNG.
    # Seed explicitly so weight init and the synthetic data repeat across runs.
    pl.seed_everything(42)

    # network
    model = MyModel()
    # data
    dm = SimpleDatamodule()

    # Use the GPU only when one is present, so the script also runs on CPU-only machines.
    trainer = Trainer(
        max_epochs=2,
        deterministic=True,
        gpus=1 if torch.cuda.is_available() else 0,
    )
    trainer.fit(model, dm)

    # save
    trainer.save_checkpoint(checkpoint_location)

    # Export as TorchScript by tracing on the example captured during training.
    model.to_torchscript(file_path="example.pt", torchscript_approach='trace')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment