Skip to content

Instantly share code, notes, and snippets.

@burrussmp
Created May 5, 2020 19:55
Show Gist options
  • Save burrussmp/fc66d82a1979644697dff2e38465f7b9 to your computer and use it in GitHub Desktop.
Save burrussmp/fc66d82a1979644697dff2e38465f7b9 to your computer and use it in GitHub Desktop.
Training and validation for annotation segmentation network
# Per-class weights for the cross-entropy loss used during training.
# Class 0 gets weight 0.25 (all others 1.0) so that it contributes less to the
# loss, compensating for the imbalance between label 0 and the other labels in
# the training data.
# NOTE(review): 27 is presumably the number of segmentation classes / model
# output channels -- confirm against the model definition.
weights = np.ones(27)
weights[0] = 0.25
class_weights = torch.FloatTensor(weights)
# Only move the weights to the GPU when one is actually present; calling
# .cuda() unconditionally makes the module fail to import on CPU-only machines.
if torch.cuda.is_available():
    class_weights = class_weights.cuda()
def train(model, device, train_loader, optimizer, epoch):
    """
    Function to train model for a single epoch

    @params
        model: PyTorch.nn.Module
            A segmentation model
        device: PyTorch.device
            Which device to allocate tensors (GPU, CPU, TPU, etc.)
        train_loader: PyTorch.DataLoader
            DataLoader class initialized with training data set. Each batch is
            indexed with the keys 'src' (model input) and 'target' (labels;
            reduced to class indices by taking the arg-max along dimension 1).
        optimizer: PyTorch.optimizer
            A training optimizer (ex. SGD, Adam, Adagrad, etc)
        epoch: int
            Epoch # that is currently being evaluated (only used for logging)
    @return
        avg_loss: PyTorch.tensor
            Detached scalar tensor on `device` holding the mean per-batch
            weighted cross-entropy loss of the epoch. If `train_loader` yields
            no batches the plain float 0.0 is returned instead.
    """
    model.train()  # training mode
    # Keep the class weights on the same device as the model outputs so the
    # loss also works when `device` is not the default CUDA device.
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    num_batches = len(train_loader)
    total_loss = 0.0
    for batch_idx, loaded in enumerate(train_loader):  # one iteration per batch
        data = loaded['src'].to(device)
        target = loaded['target'].to(device)
        optimizer.zero_grad()  # set gradients to zero
        output = model(data.float())  # get the outputs of the model
        loss = criterion(output, target.max(1)[1])
        loss.backward()  # compute the gradients for this batch
        optimizer.step()  # update model params from the stored gradients
        # Accumulate a detached value: adding `loss` itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        total_loss += loss.detach()
        if batch_idx % 20 == 0:  # provide updates on training process
            print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tAvg Loss: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * batch_idx / num_batches,
                total_loss / (batch_idx + 1 + 1e-8)))
            # Debug aid: number of output positions whose predicted class
            # (arg-max along dimension 1) is not class 0.
            print(torch.sum(output.max(1)[1] != 0))
    # The epsilon guards against division by zero for an empty loader.
    avg_loss = total_loss / (num_batches + 1e-8)
    return avg_loss
def validate(model, device, loader):
    """
    Evaluate the model on every batch of `loader` without tracking gradients.

    @params
        model: PyTorch.nn.Module
            A segmentation model
        device: PyTorch.device
            Which device to allocate tensors (GPU, CPU, TPU, etc.)
        loader: PyTorch.DataLoader
            DataLoader class initialized with validation data set. Each batch
            is indexed with the keys 'src' (model input) and 'target' (labels;
            reduced to class indices by taking the arg-max along dimension 1).
    @return
        avg_loss: PyTorch.tensor
            Scalar tensor on `device` holding the mean per-batch (unweighted)
            cross-entropy loss. If `loader` yields no batches the plain float
            0.0 is returned instead.
    """
    model.eval()  # inference mode
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0
    with torch.no_grad():
        for batch in loader:  # one iteration per validation batch
            inputs = batch['src'].to(device).float()
            _, labels = torch.max(batch['target'].to(device), 1)
            predictions = model(inputs)  # collect the outputs
            running_loss = running_loss + loss_fn(predictions, labels)
    # The epsilon guards against division by zero for an empty loader.
    avg_loss = running_loss / (len(loader) + 1e-8)
    print('\nTest set: Average loss: {:.4f}\n'.format(avg_loss))
    return avg_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment