Last active
January 14, 2021 15:30
-
-
Save johschmidt42/c4fdfbdfd58cf9ca33425a5b44c7e7df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional

import numpy as np
import torch
class Trainer:
    """Epoch-based training loop for a PyTorch model with optional validation.

    Each epoch runs one optimisation pass over ``training_DataLoader``. It then
    optionally runs one gradient-free pass over ``validation_DataLoader`` and
    optionally steps ``lr_scheduler``. Per-epoch mean losses are accumulated in
    ``training_loss`` / ``validation_loss``. The learning rate of the optimizer's
    first parameter group is accumulated in ``learning_rate``.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: Optional[torch.utils.data.DataLoader] = None,
                 lr_scheduler=None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):
        """Store the training components; no work is done until ``run_trainer``.

        Args:
            model: Network to train, called as ``model(inputs)``.
            device: Device every batch is moved to (via ``.to(device)``).
            criterion: Loss called as ``criterion(output, target)``. Its result
                must support ``.item()`` and ``.backward()``.
            optimizer: Optimizer over the model's parameters.
            training_DataLoader: Sized iterable yielding ``(x, y)`` batches.
            validation_DataLoader: Same as above for validation. ``None`` skips
                validation.
            lr_scheduler: A ``torch.optim.lr_scheduler`` instance stepped once
                per epoch, or ``None``.
            epochs: Number of epochs executed by each ``run_trainer`` call.
            epoch: Initial value of the running epoch counter.
            notebook: Use ``tqdm.notebook`` progress bars instead of console bars.
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        # Histories keep growing across repeated run_trainer calls.
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

    def _progress_tools(self):
        """Return ``(tqdm, trange)`` from the notebook or console tqdm front-end."""
        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange
        return tqdm, trange

    def run_trainer(self):
        """Train for ``self.epochs`` epochs.

        Returns:
            The tuple ``(training_loss, validation_loss, learning_rate)``. These
            are the instance's accumulated history lists, not copies.
        """
        _, trange = self._progress_tools()
        for _ in trange(self.epochs, desc='Progress'):
            self.epoch += 1  # running epoch counter (continues from `epoch`)

            self._train()

            if self.validation_DataLoader is not None:
                self._validate()

            if self.lr_scheduler is not None:
                self._step_lr_scheduler()

        return self.training_loss, self.validation_loss, self.learning_rate

    def _step_lr_scheduler(self):
        """Advance ``self.lr_scheduler`` by one epoch.

        ReduceLROnPlateau requires a metric. It receives the most recent
        validation loss when a validation loader is configured, otherwise the
        most recent training loss. Every other scheduler is stepped without
        arguments.
        """
        scheduler = self.lr_scheduler
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Use the latest entry: the histories span all run_trainer calls,
            # so an index local to the current call could point at a stale epoch.
            if self.validation_DataLoader is not None:
                metric = self.validation_loss[-1]
            else:
                metric = self.training_loss[-1]
            scheduler.step(metric)
        else:
            scheduler.step()

    def _train(self):
        """Run one training epoch and record its mean loss and learning rate.

        An empty loader records ``np.mean([])``, which is ``nan``.
        """
        tqdm, _ = self._progress_tools()
        self.model.train()  # training mode
        train_losses = []
        batch_iter = tqdm(self.training_DataLoader, 'Training',
                          total=len(self.training_DataLoader), leave=False)
        try:
            for x, y in batch_iter:
                inputs, target = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                out = self.model(inputs)
                loss = self.criterion(out, target)
                loss_value = loss.item()
                train_losses.append(loss_value)
                loss.backward()
                self.optimizer.step()
                batch_iter.set_description(f'Training: (loss {loss_value:.4f})')
        finally:
            # Close the bar even if a batch raises, so it does not linger.
            batch_iter.close()

        self.training_loss.append(np.mean(train_losses))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

    def _validate(self):
        """Run one gradient-free validation epoch and record its mean loss.

        Leaves the model in eval mode; ``_train`` switches it back.
        """
        tqdm, _ = self._progress_tools()
        self.model.eval()  # evaluation mode
        valid_losses = []
        batch_iter = tqdm(self.validation_DataLoader, 'Validation',
                          total=len(self.validation_DataLoader), leave=False)
        try:
            with torch.no_grad():
                for x, y in batch_iter:
                    inputs, target = x.to(self.device), y.to(self.device)
                    out = self.model(inputs)
                    loss_value = self.criterion(out, target).item()
                    valid_losses.append(loss_value)
                    batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')
        finally:
            batch_iter.close()

        self.validation_loss.append(np.mean(valid_losses))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.