Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/b045892f42a96274902bb349126d8c5c to your computer and use it in GitHub Desktop.
Save SannaPersson/b045892f42a96274902bb349126d8c5c to your computer and use it in GitHub Desktop.
"""
Implementation of Yolo Loss Function similar to the one in Yolov3 paper,
the difference from what I can tell is I use CrossEntropy for the classes
instead of BinaryCrossEntropy.
"""
import random
import torch
import torch.nn as nn
from utils import intersection_over_union
class YoloLoss(nn.Module):
    """YOLOv3-style loss for the predictions of a single detection scale.

    The result is a weighted sum of four terms: a no-object term and an
    object term (both binary cross-entropy on the objectness logit), a box
    term (mean squared error on the box coordinates) and a class term
    (cross-entropy on the class logits). The weights are the ``lambda_*``
    attributes set in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Constants signifying how much to pay for each respective part of the loss
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, target, anchors):
        """Compute the loss on one prediction scale.

        Neither ``predictions`` nor ``target`` is modified; all transformed
        values live in local tensors.

        :param predictions: output from model of shape: (batch size, anchors on scale, grid size, grid size, 5 + num classes).
            Last axis is read as (objectness logit, x, y, w, h, class logits...).
        :param target: targets on particular scale of shape: (batch size, anchors on scale, grid size, grid size, 6).
            Last axis is read as (objectness, x, y, w, h, class index). Cells whose
            objectness is neither 0 nor 1 (e.g. -1) are ignored.
        :param anchors: anchor boxes on the particular scale of shape (anchors on scale, 2)
            NOTE(review): presumably expressed in the same units as the target
            width/height, since the two are divided by each other - confirm.
        :return: returns the loss on the particular scale
        """
        # Check where obj and noobj (we ignore if target == -1)
        # Here we check where in the label matrix there is an object or not
        obj = target[..., 0] == 1  # in paper this is Iobj_i
        noobj = target[..., 0] == 0  # in paper this is Inoobj_i

        # Scalar zero matching the predictions' dtype and device. It stands in
        # for any term whose mask selects no cells: a mean-reduced loss over
        # zero elements would evaluate to NaN and poison the total.
        zero = predictions.new_zeros(())

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # The indexing noobj refers to the fact that we only apply the loss where there is no object
        if noobj.any():
            no_object_loss = self.bce(
                predictions[..., 0:1][noobj], target[..., 0:1][noobj]
            )
        else:
            no_object_loss = zero

        if obj.any():
            # Reshape anchors to allow for broadcasting against the
            # (batch, anchors, grid, grid, 2) width/height tensors below.
            # -1 keeps this working for any number of anchors per scale.
            anchors = anchors.reshape(1, -1, 1, 1, 2)

            # Keep only the cells that contain an object.
            xy_pred = self.sigmoid(predictions[..., 1:3])[obj]
            wh_pred = predictions[..., 3:5][obj]

            # ==================== #
            #   FOR OBJECT LOSS    #
            # ==================== #
            # Convert outputs from model to bounding boxes according to formulas in paper
            box_preds = torch.cat(
                [
                    self.sigmoid(predictions[..., 1:3]),
                    torch.exp(predictions[..., 3:5]) * anchors,
                ],
                dim=-1,
            )
            # Targets for the object prediction should be the iou of the predicted
            # bounding box and the target bounding box. Detached so the IoU acts
            # as a fixed label and receives no gradient.
            ious = intersection_over_union(
                box_preds[obj], target[..., 1:5][obj]
            ).detach()
            # Only incur loss for the cells where there is an object, signified by indexing with obj
            object_loss = self.bce(
                predictions[..., 0:1][obj], ious * target[..., 0:1][obj]
            )

            # ======================== #
            #   FOR BOX COORDINATES    #
            # ======================== #
            # x, y are compared after the sigmoid; to improve gradient flow the
            # targets' width and height are converted to the same (log-space)
            # format as the raw predictions instead of exponentiating the
            # predictions. The epsilon avoids log(0).
            wh_target = torch.log(1e-16 + target[..., 3:5] / anchors)[obj]
            xy_target = target[..., 1:3][obj]
            # compute mse loss for boxes
            box_loss = self.mse(
                torch.cat([xy_pred, wh_pred], dim=-1),
                torch.cat([xy_target, wh_target], dim=-1),
            )

            # ================== #
            #   FOR CLASS LOSS   #
            # ================== #
            # here we just apply cross entropy loss as is customary with classification problems
            class_loss = self.entropy(
                predictions[..., 5:][obj], target[..., 5][obj].long()
            )
        else:
            # No labelled object on this scale: the three object-dependent
            # terms contribute nothing.
            object_loss = zero
            box_loss = zero
            class_loss = zero

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment