from timeit import default_timer as timer

import numpy as np
import pandas as pd
import torch

# Train on the GPU if one is available
train_on_gpu = torch.cuda.is_available()


def train(model,
          criterion,
          optimizer,
          train_loader,
          valid_loader,
          save_file_name,
          max_epochs_stop=3,
          n_epochs=20,
          print_every=1):
    """Train a PyTorch model.

    Params
    --------
        model (PyTorch model): cnn to train
        criterion (PyTorch loss): objective to minimize
        optimizer (PyTorch optimizer): optimizer to compute gradients of model parameters
        train_loader (PyTorch dataloader): training dataloader to iterate through
        valid_loader (PyTorch dataloader): validation dataloader used for early stopping
        save_file_name (str ending in '.pt'): file path to save the model state dict
        max_epochs_stop (int): maximum number of epochs with no improvement in validation loss before early stopping
        n_epochs (int): maximum number of training epochs
        print_every (int): frequency of epochs to print training stats

    Returns
    --------
        model (PyTorch model): trained cnn with best weights
        history (DataFrame): history of train and validation loss and accuracy
    """
    # Early stopping initialization
    epochs_no_improve = 0
    valid_loss_min = np.inf
    valid_best_acc = 0
    history = []

    # Number of epochs already trained (if using loaded-in model weights)
    try:
        print(f'Model has been trained for: {model.epochs} epochs.\n')
    except AttributeError:
        model.epochs = 0
        print('Starting Training from Scratch.\n')

    overall_start = timer()

    # Main loop
    for epoch in range(n_epochs):
        # Keep track of training and validation loss each epoch
        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0
        valid_acc = 0

        # Set to training mode
        model.train()
        start = timer()

        # Training loop
        for ii, (data, target) in enumerate(train_loader):
            # Tensors to gpu
            if train_on_gpu:
                data, target = data.cuda(), target.cuda()

            # Clear gradients
            optimizer.zero_grad()
            # Predicted outputs are log probabilities
            output = model(data)

            # Loss and backpropagation of gradients
            loss = criterion(output, target)
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Track train loss by multiplying average loss by number of examples in batch
            train_loss += loss.item() * data.size(0)

            # Calculate accuracy by finding the max log probability
            _, pred = torch.max(output, dim=1)
            correct_tensor = pred.eq(target.data.view_as(pred))
            # Need to convert correct tensor from int to float to average
            accuracy = torch.mean(correct_tensor.type(torch.FloatTensor))
            # Multiply average accuracy by the number of examples in batch
            train_acc += accuracy.item() * data.size(0)

            # Track training progress
            print(
                f'Epoch: {epoch}\t{100 * (ii + 1) / len(train_loader):.2f}% complete. {timer() - start:.2f} seconds elapsed in epoch.',
                end='\r')

        # After the training loop ends, start validation
        else:
            model.epochs += 1

            # Don't need to keep track of gradients
            with torch.no_grad():
                # Set to evaluation mode
                model.eval()

                # Validation loop
                for data, target in valid_loader:
                    # Tensors to gpu
                    if train_on_gpu:
                        data, target = data.cuda(), target.cuda()

                    # Forward pass
                    output = model(data)

                    # Validation loss
                    loss = criterion(output, target)
                    # Multiply average loss by the number of examples in batch
                    valid_loss += loss.item() * data.size(0)

                    # Calculate validation accuracy
                    _, pred = torch.max(output, dim=1)
                    correct_tensor = pred.eq(target.data.view_as(pred))
                    accuracy = torch.mean(
                        correct_tensor.type(torch.FloatTensor))
                    # Multiply average accuracy by the number of examples
                    valid_acc += accuracy.item() * data.size(0)

                # Calculate average losses
                train_loss = train_loss / len(train_loader.dataset)
                valid_loss = valid_loss / len(valid_loader.dataset)

                # Calculate average accuracy
                train_acc = train_acc / len(train_loader.dataset)
                valid_acc = valid_acc / len(valid_loader.dataset)

                history.append([train_loss, valid_loss, train_acc, valid_acc])

                # Print training and validation results
                if (epoch + 1) % print_every == 0:
                    print(
                        f'\nEpoch: {epoch} \tTraining Loss: {train_loss:.4f} \tValidation Loss: {valid_loss:.4f}'
                    )
                    print(
                        f'\t\tTraining Accuracy: {100 * train_acc:.2f}%\t Validation Accuracy: {100 * valid_acc:.2f}%'
                    )

                # Save the model if validation loss decreases
                if valid_loss < valid_loss_min:
                    # Save model
                    torch.save(model.state_dict(), save_file_name)
                    # Track improvement
                    epochs_no_improve = 0
                    valid_loss_min = valid_loss
                    valid_best_acc = valid_acc
                    best_epoch = epoch

                # Otherwise increment count of epochs with no improvement
                else:
                    epochs_no_improve += 1
                    # Trigger early stopping
                    if epochs_no_improve >= max_epochs_stop:
                        print(
                            f'\nEarly Stopping! Total epochs: {epoch}. Best epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
                        )
                        total_time = timer() - overall_start
                        print(
                            f'{total_time:.2f} total seconds elapsed. {total_time / (epoch + 1):.2f} seconds per epoch.'
                        )

                        # Load the best state dict
                        model.load_state_dict(torch.load(save_file_name))
                        # Attach the optimizer
                        model.optimizer = optimizer

                        # Format history
                        history = pd.DataFrame(
                            history,
                            columns=[
                                'train_loss', 'valid_loss', 'train_acc',
                                'valid_acc'
                            ])
                        return model, history

    # Load the best state dict so the returned model has the best weights,
    # as promised in the docstring
    model.load_state_dict(torch.load(save_file_name))
    # Attach the optimizer
    model.optimizer = optimizer

    # Record overall time and print out stats
    total_time = timer() - overall_start
    print(
        f'\nBest epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
    )
    print(
        f'{total_time:.2f} total seconds elapsed. {total_time / (epoch + 1):.2f} seconds per epoch.'
    )

    # Format history
    history = pd.DataFrame(
        history,
        columns=['train_loss', 'valid_loss', 'train_acc', 'valid_acc'])
    return model, history
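
# --- Hypothetical setup sketch (not part of the original gist) ---
# The call below assumes `model`, `criterion`, `optimizer`, `dataloaders`,
# and `save_file_name` are already defined. One way to build them, assuming
# an ImageFolder-style dataset under data/train and data/val; the paths,
# model choice, and hyperparameters here are illustrative assumptions:
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
data = {
    split: datasets.ImageFolder(f'data/{split}', transform=transform)
    for split in ('train', 'val')
}
dataloaders = {
    split: DataLoader(data[split], batch_size=128, shuffle=(split == 'train'))
    for split in ('train', 'val')
}

# The training loop treats model outputs as log probabilities, so end the
# network with LogSoftmax and pair it with a negative log likelihood loss
model = models.vgg16(pretrained=True)
model.classifier[6] = nn.Sequential(
    nn.Linear(model.classifier[6].in_features, len(data['train'].classes)),
    nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())
save_file_name = 'vgg16-transfer.pt'
if train_on_gpu:
    model = model.cuda()
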
# Running the model
model, history = train(
    model,
    criterion,
    optimizer,
    dataloaders['train'],
    dataloaders['val'],
    save_file_name=save_file_name,
    max_epochs_stop=3,
    n_epochs=100,
    print_every=1)
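
# --- Optional: inspect the returned history (a minimal sketch, assuming
# matplotlib is installed; not part of the original gist) ---
# `history` is a DataFrame with one row per epoch and columns train_loss,
# valid_loss, train_acc, valid_acc, so the learning curves plot directly.
import matplotlib.pyplot as plt

history[['train_loss', 'valid_loss']].plot()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.show()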