Skip to content

Instantly share code, notes, and snippets.

@johschmidt42
Last active January 14, 2021 15:30
Show Gist options
  • Save johschmidt42/c4fdfbdfd58cf9ca33425a5b44c7e7df to your computer and use it in GitHub Desktop.
Save johschmidt42/c4fdfbdfd58cf9ca33425a5b44c7e7df to your computer and use it in GitHub Desktop.
import numpy as np
import torch
class Trainer:
    """Minimal epoch-based training loop for a PyTorch model.

    Each epoch runs one pass over the training loader, optionally one pass
    over the validation loader, and optionally one learning-rate scheduler
    step. The per-epoch mean losses and the learning rate of the optimizer's
    first parameter group are accumulated on the instance.

    Args:
        model: Network to train. Only the batches are sent to ``device``;
            the model itself is not moved by this class.
        device: Device every batch is sent to before the forward pass.
        criterion: Loss module, called as ``criterion(output, target)``.
        optimizer: Optimizer that updates the model parameters.
        training_DataLoader: Sized iterable yielding ``(x, y)`` batches.
        validation_DataLoader: Optional sized iterable yielding ``(x, y)``
            batches; when ``None`` the validation pass is skipped.
        lr_scheduler: Optional scheduler, stepped once per epoch.
        epochs: Number of epochs performed by one ``run_trainer`` call.
        epoch: Starting value of the epoch counter.
        notebook: Use the ``tqdm.notebook`` progress bars instead of the
            console ones.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: torch.utils.data.DataLoader = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        # Per-epoch history, one entry appended per completed epoch.
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

    def run_trainer(self):
        """Train for ``self.epochs`` epochs.

        Returns:
            Tuple ``(training_loss, validation_loss, learning_rate)`` of the
            history lists stored on the instance. ``validation_loss`` stays
            empty when no validation loader was given.
        """
        _, trange = self._progress_tools()

        progressbar = trange(self.epochs, desc='Progress')
        for _ in progressbar:
            self.epoch += 1  # epoch counter

            self._train()

            if self.validation_DataLoader is not None:
                self._validate()

            if self.lr_scheduler is not None:
                self._step_scheduler()
        return self.training_loss, self.validation_loss, self.learning_rate

    def _progress_tools(self):
        """Return ``(tqdm, trange)`` from the notebook or console flavour of tqdm."""
        # Imported lazily so that the notebook widgets are only required when
        # ``notebook=True`` was requested.
        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange
        return tqdm, trange

    def _step_scheduler(self):
        """Advance the learning-rate scheduler by one epoch."""
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step() requires the monitored metric. Use the
            # loss of the epoch that just finished ([-1], not the loop index,
            # so repeated run_trainer() calls read the right entry). Without a
            # validation loader the training loss is monitored instead.
            if self.validation_DataLoader is not None:
                metric = self.validation_loss[-1]
            else:
                metric = self.training_loss[-1]
            self.lr_scheduler.step(metric)
        else:
            self.lr_scheduler.step()

    def _train(self):
        """Run one training epoch and record its mean loss and learning rate."""
        tqdm, _ = self._progress_tools()

        self.model.train()  # train mode
        train_losses = []  # per-batch losses of this epoch
        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training',
                          total=len(self.training_DataLoader), leave=False)

        for _, (x, y) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)
            self.optimizer.zero_grad()  # clear gradients of the previous batch
            out = self.model(inputs)  # one forward pass
            loss = self.criterion(out, target)  # calculate loss
            loss_value = loss.item()
            train_losses.append(loss_value)
            loss.backward()  # one backward pass
            self.optimizer.step()  # update the parameters

            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')  # update progressbar

        self.training_loss.append(np.mean(train_losses))
        # Recorded before the scheduler steps, i.e. the rate used in this epoch.
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

        batch_iter.close()

    def _validate(self):
        """Run one validation epoch (no gradients) and record its mean loss."""
        tqdm, _ = self._progress_tools()

        self.model.eval()  # evaluation mode
        valid_losses = []  # per-batch losses of this epoch
        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation',
                          total=len(self.validation_DataLoader), leave=False)

        # Gradient tracking is disabled for the whole pass.
        with torch.no_grad():
            for _, (x, y) in batch_iter:
                inputs, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)
                out = self.model(inputs)
                loss = self.criterion(out, target)
                loss_value = loss.item()
                valid_losses.append(loss_value)

                batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(valid_losses))

        batch_iter.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment