-
-
Save SannaPersson/b045892f42a96274902bb349126d8c5c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Implementation of the YOLO loss function, similar to the one in the YOLOv3 paper.
The difference, as far as I can tell, is that CrossEntropy is used for the classes
instead of BinaryCrossEntropy.
"""
import random | |
import torch | |
import torch.nn as nn | |
from utils import intersection_over_union | |
class YoloLoss(nn.Module):
    """YOLOv3-style loss for a single prediction scale.

    The loss is a weighted sum of four terms: a no-object term, an object
    (objectness) term, a box-coordinate term and a class term. Classes are
    scored with cross entropy rather than per-class binary cross entropy.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Constants signifying how much to pay for each respective part of the loss
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, target, anchors):
        """Compute the loss on one scale without modifying the inputs.

        :param predictions: output from model of shape
            (batch size, anchors on scale, grid size, grid size, 5 + num classes);
            the last axis is read as (objectness logit, x, y, w, h, class logits...)
        :param target: targets on this scale of shape
            (batch size, anchors on scale, grid size, grid size, 6);
            the last axis is read as (objectness, x, y, w, h, class index).
            Objectness 1 marks an object, 0 marks no object; any other value
            (e.g. -1) is ignored by every term.
        :param anchors: anchor boxes on this scale of shape (anchors on scale, 2)
        :return: scalar tensor with the weighted loss on this scale
        """
        # Check where obj and noobj (cells with any other marker are ignored)
        obj = target[..., 0] == 1  # in paper this is Iobj_i
        noobj = target[..., 0] == 0  # in paper this is Inoobj_i

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # A mean over an empty selection is NaN, so fall back to zero when
        # there is nothing to penalise.
        if noobj.any():
            no_object_loss = self.bce(
                predictions[..., 0:1][noobj], target[..., 0:1][noobj]
            )
        else:
            no_object_loss = predictions.new_zeros(())

        # Without any object the remaining three terms would each be a mean
        # over an empty selection (NaN); only the no-object term applies.
        if not obj.any():
            return self.lambda_noobj * no_object_loss

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # Reshape anchors to allow broadcasting against (B, A, S, S, 2);
        # -1 keeps this working for any number of anchors per scale.
        anchors = anchors.reshape(1, -1, 1, 1, 2)

        # Convert outputs from model to bounding boxes according to formulas in paper
        xy_preds = self.sigmoid(predictions[..., 1:3])
        box_preds = torch.cat(
            [xy_preds, torch.exp(predictions[..., 3:5]) * anchors], dim=-1
        )

        # Target for the objectness score is the IoU between the predicted and
        # the target box; detached so no gradient flows through the IoU.
        # NOTE(review): intersection_over_union comes from utils and is not
        # visible here -- presumably it takes two (N, 4) tensors and returns
        # (N, 1); confirm against its definition.
        ious = intersection_over_union(
            box_preds[obj], target[..., 1:5][obj]
        ).detach()
        object_loss = self.bce(
            predictions[..., 0:1][obj], ious * target[..., 0:1][obj]
        )

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #
        # x, y are compared after the sigmoid; w, h are compared in log space
        # (raw network output vs. log(target / anchor)) to improve gradient
        # flow. New tensors are built here instead of writing into
        # `predictions` / `target`, so the caller's tensors stay intact.
        box_outputs = torch.cat([xy_preds, predictions[..., 3:5]], dim=-1)
        wh_targets = torch.log(1e-16 + target[..., 3:5] / anchors)
        box_targets = torch.cat([target[..., 1:3], wh_targets], dim=-1)
        box_loss = self.mse(box_outputs[obj], box_targets[obj])

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #
        # Cross entropy over the class logits; the target stores the class index.
        class_loss = self.entropy(
            predictions[..., 5:][obj], target[..., 5][obj].long()
        )

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment