Created January 8, 2022
An implementation of RetinaNet.
# RetinaNet Implementation in PyTorch
Implementation of the following paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
import copy
import math
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import functional as FT
from torchvision.transforms import transforms as T
class Compose:
    """Chain joint ``(image, target)`` transforms into a single callable.

    Every transform in ``transforms`` is called as ``t(image, target)`` and
    must return the (possibly modified) ``(image, target)`` pair, which is
    then handed to the next transform.
    """

    def __init__(self, transforms):
        # Sequence of callables, applied in the order given.
        self.transforms = transforms

    def __call__(self, image, target):
        """Run ``image`` and ``target`` through every transform in order."""
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
class Normalizer(object):
    """Normalize an image tensor channel-wise and pass the target through.

    The per-channel mean and standard deviation are fixed in ``__init__``;
    the target is returned unchanged.
    """

    def __init__(self):
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.normalize = T.Compose([T.Normalize(mean=self.mean, std=self.std)])

    def __call__(self, image, target):
        """Return ``(normalized_image, target)``."""
        return self.normalize(image), target
class Resize(object):
    """Resize a PIL image and scale the boxes of its target accordingly.

    Args:
        size: Either an ``int`` (target length of the shorter image side,
            aspect ratio preserved) or an ``(out_width, out_height)`` pair
            (exact output size, aspect ratio not preserved).
        max_size: Upper bound for the longer side when ``size`` is an int.
            If scaling the shorter side to ``size`` would push the longer
            side past this bound, the image is scaled so the longer side
            equals ``max_size`` instead.
    """

    def __init__(self, size=400, max_size=800):
        self.size = size
        self.max_size = max_size

    def __call__(self, img, target):
        """Return the resized image and the target with rescaled boxes.

        ``target`` is a list of dicts, each with a 4-element ``'bbox'``
        entry; elements 0 and 2 are scaled by the horizontal factor and
        elements 1 and 3 by the vertical factor. The dicts are updated in
        place. An empty ``target`` is returned unchanged.
        """
        w, h = img.size
        if isinstance(self.size, int):
            shorter, longer = min(w, h), max(w, h)
            sw = sh = float(self.size) / shorter
            if sw * longer > self.max_size:
                sw = sh = float(self.max_size) / longer
            # +0.5 rounds to the nearest integer pixel count.
            ow = int(w * sw + 0.5)
            oh = int(h * sh + 0.5)
        else:
            ow, oh = self.size
            sw = float(ow) / w
            sh = float(oh) / h

        # Scale box by box in plain Python so an empty target list is fine.
        for annotation in target:
            b0, b1, b2, b3 = annotation['bbox']
            annotation['bbox'] = [b0 * sw, b1 * sh, b2 * sw, b3 * sh]

        return img.resize((ow, oh), Image.BILINEAR), target
class ToTensor(nn.Module):
    """Convert a PIL image to a tensor and rescale its dtype.

    The image goes through ``FT.pil_to_tensor`` followed by
    ``FT.convert_image_dtype`` with that function's default target dtype.
    The target is passed through unchanged.
    """

    def forward(self, image, target=None):
        """Return ``(image_tensor, target)``."""
        image = FT.pil_to_tensor(image)
        image = FT.convert_image_dtype(image)
        return image, target
class PILToTensor(nn.Module):
    """Convert a PIL image to a tensor without any dtype conversion.

    Only ``FT.pil_to_tensor`` is applied; the target is passed through
    unchanged.
    """

    def forward(self, image, target=None):
        """Return ``(image_tensor, target)``."""
        image = FT.pil_to_tensor(image)
        return image, target
"#### COLAB LOADER ####\n",
"# !curl -L \"<DATASET_URL>\" > dataset.zip; unzip dataset.zip; rm dataset.zip\n",
"# Use for colab only #\n",
# Root directory of the Aquarium dataset. It is passed as ``root`` to
# ``AquariumDetection``, which joins it with the split name to find
# ``_annotations.coco.json`` and the image files.
# NOTE(review): this is a machine-specific absolute path - adjust it per environment.
dataset_path = "/Volumes/Samsung_T5/Documents/MachineLearning/machine_learning_notebooks/pytorch/aquarium-dataset/Aquarium Combined/"
def xyxy_2_xywh(boxes):
    """Convert boxes from corner format to centre/size format.

    Args:
        boxes: ``[N, 4]`` array-like or tensor of ``(x1, y1, x2, y2)``.

    Returns:
        Float tensor ``[N, 4]`` of ``(cx, cy, w, h)``. Width and height use
        the inclusive-pixel convention, i.e. ``w = x2 - x1 + 1``.
    """
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    top_left = boxes[:, :2]
    bottom_right = boxes[:, 2:]
    return torch.cat([(top_left + bottom_right) / 2, bottom_right - top_left + 1], 1)
" \n",
def xywh_2_xyxy(boxes):
    """Convert boxes from centre/size format to corner format.

    Args:
        boxes: ``[N, 4]`` array-like or tensor of ``(cx, cy, w, h)``.

    Returns:
        Float tensor ``[N, 4]`` of ``(cx - w/2, cy - h/2, cx + w/2, cy + h/2)``.
    """
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    centre = boxes[:, :2]
    size = boxes[:, 2:]
    return torch.cat([centre - size / 2, centre + size / 2], 1)
def box_nms(bboxes, scores, threshold=0.5, mode='union'):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose overlap with it exceeds ``threshold``.

    Args:
        bboxes: ``[N, 4]`` tensor of ``(x1, y1, x2, y2)`` boxes.
        scores: ``[N]`` tensor of box scores.
        threshold: Boxes with overlap ``<= threshold`` survive a round.
        mode: ``'union'`` divides the intersection by the union area (IoU);
            ``'min'`` divides it by the smaller of the two box areas.

    Returns:
        ``LongTensor`` with the indices of the kept boxes, highest score
        first.

    Raises:
        TypeError: If ``mode`` is neither ``'union'`` nor ``'min'`` (raised
            once at least two boxes have to be compared).
    """
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    # Inclusive-pixel convention: a box from x1 to x2 spans x2 - x1 + 1.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    _, order = scores.sort(0, descending=True)
    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break

        best = order[0]
        rest = order[1:]
        keep.append(best.item())

        # Intersection of the best box with every remaining box.
        xx1 = x1[rest].clamp(min=x1[best])
        yy1 = y1[rest].clamp(min=y1[best])
        xx2 = x2[rest].clamp(max=x2[best])
        yy2 = y2[rest].clamp(max=y2[best])
        inter = (xx2 - xx1 + 1).clamp(min=0) * (yy2 - yy1 + 1).clamp(min=0)

        if mode == 'union':
            ovr = inter / (areas[best] + areas[rest] - inter)
        elif mode == 'min':
            ovr = inter / areas[rest].clamp(max=areas[best])
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)

        # flatten() keeps the index tensor 1-D even when one box survives.
        survivors = (ovr <= threshold).nonzero().flatten()
        if survivors.numel() == 0:
            break
        # +1 because `ovr` is indexed relative to order[1:].
        order = order[survivors + 1]
    return torch.tensor(keep, dtype=torch.long)
def iou(box1, box2, order="xyxy"):
    """Pairwise intersection-over-union between two sets of boxes.

    Args:
        box1: ``[N, 4]`` tensor of boxes.
        box2: ``[M, 4]`` tensor of boxes.
        order: ``"xyxy"`` if the boxes are ``(x1, y1, x2, y2)``; ``"xywh"``
            if they are ``(cx, cy, w, h)``, in which case both sets are
            first converted with ``xywh_2_xyxy``.

    Returns:
        ``[N, M]`` tensor where entry ``(i, j)`` is the IoU of ``box1[i]``
        and ``box2[j]``. Areas use the inclusive-pixel ``+1`` convention.
    """
    if order == "xywh":
        box1 = xywh_2_xyxy(box1)
        box2 = xywh_2_xyxy(box2)

    # Broadcasting box1 over a new middle axis yields all N x M pairs.
    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]
    wh = (rb - lt + 1).clamp(min=0)                 # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]               # [N,M]

    area1 = (box1[:, 2] - box1[:, 0] + 1) * (box1[:, 3] - box1[:, 1] + 1)  # [N,]
    area2 = (box2[:, 2] - box2[:, 0] + 1) * (box2[:, 3] - box2[:, 1] + 1)  # [M,]
    return inter / (area1[:, None] + area2 - inter)
"### Anchor Boxes\n",
class AnchorBox:
    """Generate anchor boxes for pyramid levels P3 to P7.

    Each level has one base area and one stride. At every feature-map cell
    ``A = len(ratios) * len(scales)`` anchors are placed, one per
    (aspect ratio, scale) combination.
    """

    def __init__(self):
        self.ratios = [0.5, 1, 2]
        self.scales = [1, 2 ** (1 / 3), 2 ** (2 / 3)]

        self.A = len(self.ratios) * len(self.scales)  # anchors per cell
        self.areas = [x ** 2 for x in [32, 64, 128, 256, 512]]  # P3, P4, P5, P6, P7
        self.strides = [2 ** i for i in range(3, 8)]  # level l is 2^l smaller than the input
        self.anchor_dims = self._anchor_dims()

    def _meshgrid(self, x, y, row_major=True):
        """Return all integer grid coordinates of an ``x`` by ``y`` grid.

        The result has shape ``[x * y, 2]``. With ``row_major=True`` each
        row is ``(column, row)`` and the column index varies fastest;
        otherwise the two output columns are swapped.
        """
        a = torch.arange(0, x)
        b = torch.arange(0, y)
        xx = a.repeat(y).view(-1, 1)
        yy = b.view(-1, 1).repeat(1, x).view(-1, 1)
        return torch.cat([xx, yy], 1) if row_major else torch.cat([yy, xx], 1)

    def _anchor_dims(self):
        """Return anchor ``(width, height)`` pairs, shape ``[levels, A, 2]``.

        For a given area and ratio the base box satisfies
        ``width * height == area`` and ``width / height == ratio``; every
        scale is applied to that base box independently.
        """
        anchor_dims = []
        for area in self.areas:
            for ratio in self.ratios:
                base_height = math.sqrt(area / ratio)
                base_width = area / base_height
                for scale in self.scales:
                    # Scale the base box each time; rescaling the previous
                    # result would compound the scales.
                    anchor_dims.append([base_width * scale, base_height * scale])
        return torch.FloatTensor(anchor_dims).view(len(self.areas), -1, 2)

    def generate_anchor_boxes(self, input_size):
        """Generate every anchor box for an input of the given size.

        Args:
            input_size: ``torch.Tensor`` holding ``(w, h)`` of the input.

        Returns:
            ``[num_anchors, 4]`` tensor of ``(cx, cy, w, h)`` boxes, ordered
            level by level (P3 first); within a level row by row, then
            column by column, then by anchor index.
        """
        # Feature-map size of each level: ceil(input / stride).
        feature_map_sizes = [(input_size / stride).ceil() for stride in self.strides]

        boxes = []
        for level, fm_size in enumerate(feature_map_sizes):
            grid_size = input_size / fm_size
            fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
            # +0.5 moves each anchor centre to the middle of its grid cell.
            xy = self._meshgrid(fm_w, fm_h) + 0.5  # [fm_h*fm_w, 2]
            xy = (xy * grid_size).view(fm_h, fm_w, 1, 2).expand(fm_h, fm_w, self.A, 2)
            wh = self.anchor_dims[level].view(1, 1, self.A, 2).expand(fm_h, fm_w, self.A, 2)
            box = torch.cat([xy, wh], 3)  # [x,y,w,h]
            boxes.append(box.view(-1, 4))
        return torch.cat(boxes, 0)
class Encoder:
    """Convert between ground-truth boxes and per-anchor training targets."""

    def __init__(self):
        self.anchor_box = AnchorBox()

    @staticmethod
    def _size_tensor(input_size):
        """Return ``input_size`` as a 2-element tensor; an int means square."""
        if isinstance(input_size, int):
            return torch.Tensor([input_size, input_size])
        return torch.Tensor(input_size)

    def encode(self, boxes, labels, input_size):
        """Assign each anchor its best-overlapping box and build targets.

        Args:
            boxes: ``[#obj, 4]`` array-like or tensor. Both the IoU
                computation and the offsets below treat these as
                ``(cx, cy, w, h)``, the same format as the anchors.
                NOTE(review): the dataset in this file forwards COCO
                ``bbox`` entries unchanged - confirm they are
                centre-based, not top-left-based.
            labels: ``[#obj]`` array-like or tensor of class indices.
            input_size: ``int`` or ``(w, h)`` of the network input.

        Returns:
            ``(loc_targets, cls_targets)``: ``loc_targets`` is
            ``[#anchors, 4]`` holding ``(dx, dy, log dw, log dh)`` relative
            to each anchor. ``cls_targets`` is ``[#anchors]`` with
            ``label + 1`` for anchors whose best IoU is at least 0.5, ``0``
            for background, and ``-1`` for anchors with IoU strictly
            between 0.4 and 0.5, which are to be ignored.
        """
        input_size = self._size_tensor(input_size)
        anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)

        # Work on torch tensors only, so tensor indices and masks are never
        # applied to numpy arrays.
        if torch.is_tensor(boxes):
            boxes = boxes.float()
        else:
            boxes = torch.from_numpy(np.asarray(boxes, dtype=np.float32))
        if not torch.is_tensor(labels):
            labels = torch.from_numpy(np.asarray(labels, dtype=np.float32))

        ious = iou(anchor_boxes, boxes, order="xywh")
        max_ious, max_ids = ious.max(1)
        boxes = boxes[max_ids]

        loc_xy = (boxes[:, :2] - anchor_boxes[:, :2]) / anchor_boxes[:, 2:]
        loc_wh = torch.log(boxes[:, 2:] / anchor_boxes[:, 2:])
        loc_targets = torch.cat([loc_xy, loc_wh], 1)

        cls_targets = 1 + labels[max_ids]
        cls_targets[max_ious < 0.5] = 0
        ignore = (max_ious > 0.4) & (max_ious < 0.5)  # ignore ious between [0.4,0.5]
        cls_targets[ignore] = -1  # for now just mark ignored to -1
        return loc_targets, cls_targets

    def decode(self, loc_preds, cls_preds, input_size):
        """Turn network outputs for one image into final detections.

        Args:
            loc_preds: ``[#anchors, 4]`` predicted offsets.
            cls_preds: ``[#anchors, #classes]`` class logits.
            input_size: ``int`` or ``(w, h)`` of the network input.

        Returns:
            ``(boxes, labels)``: ``[#det, 4]`` boxes as
            ``(x1, y1, x2, y2)`` and ``[#det]`` class indices, after
            thresholding the best sigmoid score at 0.5 and applying NMS
            with threshold 0.5.
        """
        input_size = self._size_tensor(input_size)

        CLS_THRESH = 0.5
        NMS_THRESH = 0.5

        anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)
        loc_xy = loc_preds[:, :2]
        loc_wh = loc_preds[:, 2:]

        # Invert the encoding done in `encode`.
        xy = loc_xy * anchor_boxes[:, 2:] + anchor_boxes[:, :2]
        wh = loc_wh.exp() * anchor_boxes[:, 2:]
        boxes = torch.cat([xy - wh / 2, xy + wh / 2], 1)

        score, labels = cls_preds.sigmoid().max(1)
        # squeeze(1) rather than squeeze(): a single surviving detection
        # must stay a 1-element index vector, or boxes[ids] loses a
        # dimension and NMS fails.
        ids = (score > CLS_THRESH).nonzero().squeeze(1)
        keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)
        return boxes[ids][keep], labels[ids][keep]
class AquariumDetection(datasets.VisionDataset):
    """COCO-annotated detection dataset laid out as ``root/<split>/``.

    Each split directory is expected to contain ``_annotations.coco.json``
    and the image files it references. Images without any annotation are
    dropped when the dataset is built.

    NOTE(review): ``COCO`` is not imported anywhere in the visible part of
    this file - presumably ``pycocotools.coco.COCO``; confirm and import it.
    """

    def __init__(
        self,
        root: str,
        split="train",
        transform=None,
        target_transform=None,
        transforms=None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        self.split = split
        self.coco = COCO(os.path.join(root, split, "_annotations.coco.json"))
        # Keep only image ids that have at least one annotation.
        self.ids = [
            img_id
            for img_id in sorted(self.coco.imgs.keys())
            if len(self._load_target(img_id)) > 0
        ]

    def _load_image(self, id: int) -> Image.Image:
        """Load the image with the given COCO id as an RGB PIL image."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        img = Image.open(os.path.join(self.root, self.split, path)).convert("RGB")
        return img

    def _load_target(self, id: int):
        """Return the list of COCO annotation dicts for the given image id."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int):
        """Return ``(image, annot)`` for the sample at ``index``.

        ``annot`` is a list with one entry per object: the annotation's
        ``bbox`` values followed by its ``category_id``.
        """
        id = self.ids[index]
        image = self._load_image(id)
        # Deep copy so transforms that edit boxes in place cannot corrupt
        # the annotations cached inside the COCO object.
        target = copy.deepcopy(self._load_target(id))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        annot = [t["bbox"] + [t["category_id"]] for t in target]
        return image, annot

    def __len__(self) -> int:
        return len(self.ids)
def collate_fn(batch):
    """Pad a batch of differently sized images and encode their targets.

    Args:
        batch: Sequence of ``(image, annot)`` samples, where ``image`` is a
            ``[3, H, W]`` tensor and ``annot`` is a list of
            ``[b0, b1, b2, b3, label]`` rows.

    Returns:
        Dict with ``'img'`` (``[B, 3, max_H, max_W]``, zero-padded on the
        bottom and right), ``'loc_targets'`` (``[B, #anchors, 4]``) and
        ``'cls_targets'`` (``[B, #anchors]``), all float tensors.
    """
    imgs = [sample[0] for sample in batch]
    annots = [sample[1] for sample in batch]

    max_height = max(int(img.shape[1]) for img in imgs)
    max_width = max(int(img.shape[2]) for img in imgs)

    padded_imgs = torch.zeros(len(imgs), 3, max_height, max_width)
    for i, img in enumerate(imgs):
        padded_imgs[i, :, :int(img.shape[1]), :int(img.shape[2])] = img

    ## Encode ##
    encoder = Encoder()
    # The anchor generator expects (w, h); image tensors are (C, H, W).
    input_size = (max_width, max_height)
    loc_targets = []
    cls_targets = []
    for annot in annots:
        # Per-sample float array; avoids ragged or object-dtype batch arrays.
        annot = np.asarray(annot, dtype=np.float32).reshape(-1, 5)
        loc_target, cls_target = encoder.encode(annot[:, 0:4], annot[:, 4], input_size)
        loc_targets.append(torch.as_tensor(loc_target, dtype=torch.float32))
        cls_targets.append(torch.as_tensor(cls_target, dtype=torch.float32))

    return {
        'img': padded_imgs,
        'loc_targets': torch.stack(loc_targets),
        'cls_targets': torch.stack(cls_targets),
    }
def get_transform(train):
    """Build the joint image/target preprocessing pipeline.

    The pipeline is: resize (shorter side 300), convert to a float tensor,
    then normalize.

    Args:
        train: Currently unused; the same pipeline is returned for every
            split.
    """
    return Compose([
        Resize(size=300),
        ToTensor(),
        Normalizer(),
    ])
# One dataset per split, all rooted at `dataset_path` and sharing the same
# preprocessing pipeline.
train_dataset = AquariumDetection(root=dataset_path, transforms=get_transform(True))
val_dataset = AquariumDetection(root=dataset_path, split="valid", transforms=get_transform(False))
test_dataset = AquariumDetection(root=dataset_path, split="test", transforms=get_transform(False))
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution with padding equal to ``dilation``."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    The block outputs ``planes * expansion`` channels. When the stride is
    not 1 or the input channel count differs from the output channel count,
    the shortcut is a strided 1x1 convolution followed by batch norm;
    otherwise it is the identity.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, groups=1,
                 base_width=64, dilation=1):
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d
        out_planes = planes * self.expansion
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)

        if stride != 1 or inplanes != out_planes:
            # Projection shortcut so the residual matches the main branch
            # in both resolution and channel count.
            self.downsample = nn.Sequential(
                conv1x1(inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            # An empty Sequential acts as the identity.
            self.downsample = nn.Sequential()

        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck to ``x`` and add the shortcut branch."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += self.downsample(x)
        return self.relu(out)
class FPN(nn.Module):
    """Residual backbone with a feature pyramid producing P3 to P7.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride)``. The
            lateral and ``conv6`` layers hard-code the input channels 512,
            1024 and 2048, so ``expansion`` has to be 4.
        num_blocks: Four integers, the number of blocks in stages
            ``conv2`` to ``conv5``.
    """

    def __init__(self, block, num_blocks):
        super(FPN, self).__init__()
        # Running input-channel count; updated by _make_layer as it goes.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = self._make_layer(block, 64, num_blocks=num_blocks[0], stride=1)
        self.conv3 = self._make_layer(block, 128, num_blocks=num_blocks[1], stride=2)
        self.conv4 = self._make_layer(block, 256, num_blocks=num_blocks[2], stride=2)
        self.conv5 = self._make_layer(block, 512, num_blocks=num_blocks[3], stride=2)

        # Extra coarse levels computed with strided convolutions.
        self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)

        ## lateral layers ##
        self.lat1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)
        self.lat2 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.lat3 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)

        ## top-down layers ##
        self.topdown1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.topdown2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()

    def _upsample_and_add(self, x, y):
        """Bilinearly resize ``x`` to the spatial size of ``y`` and add them.

        Resizing to an explicit size, not a fixed factor of 2, keeps the
        addition valid when the finer map has an odd height or width.
        """
        _, _, H, W = y.size()
        # F.interpolate replaces the deprecated F.upsample;
        # align_corners=False is the value F.upsample defaulted to.
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the feature maps ``(p3, p4, p5, p6, p7)``, 256 channels each."""
        # bottom up
        c1 = self.relu(self.bn1(self.conv1(x)))
        c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)

        p6 = self.conv6(c5)
        p7 = self.conv7(p6)

        # top down
        p5 = self.lat1(c5)
        p4 = self._upsample_and_add(p5, self.lat2(c4))
        p4 = self.topdown1(p4)
        p3 = self._upsample_and_add(p4, self.lat3(c3))
        p3 = self.topdown2(p3)
        return p3, p4, p5, p6, p7
class ClassificationHead(nn.Module):
    """Convolutional head predicting class logits for every anchor.

    Four 3x3 conv + ReLU layers (256 channels) are followed by a 3x3 conv
    with ``n_classes * n_anchors`` output channels. The spatial size of the
    input is preserved.
    """

    def __init__(self, n_classes=8):
        super(ClassificationHead, self).__init__()
        self.n_anchors = 9
        self.n_classes = n_classes

        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        # Final projection: K classes for each of the A anchors.
        layers.append(
            nn.Conv2d(256, self.n_classes * self.n_anchors, kernel_size=3, stride=1, padding=1)
        )
        self.convnet = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[B, 256, H, W]`` features to ``[B, n_classes * 9, H, W]`` logits."""
        return self.convnet(x)
class RegressionHead(nn.Module):
    """Convolutional head predicting four box offsets for every anchor.

    Four 3x3 conv + ReLU layers (256 channels) are followed by a 3x3 conv
    with ``4 * n_anchors`` output channels. The spatial size of the input
    is preserved.
    """

    def __init__(self):
        super(RegressionHead, self).__init__()
        self.n_anchors = 9

        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        # Final projection: 4 box offsets for each of the A anchors.
        layers.append(nn.Conv2d(256, 4 * self.n_anchors, kernel_size=3, stride=1, padding=1))
        self.convnet = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[B, 256, H, W]`` features to ``[B, 36, H, W]`` offsets."""
        return self.convnet(x)
class RetinaNet(nn.Module):
    """FPN backbone with classification and regression heads shared across levels.

    Args:
        n_classes: Number of object classes predicted per anchor.
    """

    def __init__(self, n_classes=8):
        super(RetinaNet, self).__init__()

        self.fpn = FPN(Bottleneck, [3, 4, 6, 3])

        self.num_classes = n_classes

        self.classification_head = ClassificationHead(n_classes=self.num_classes)  # class head
        self.regression_head = RegressionHead()  # loc head

    def forward(self, x):
        """Run the detector on a batch of images.

        Returns:
            ``(loc_preds, cls_preds)`` with shapes ``[B, #anchors, 4]`` and
            ``[B, #anchors, n_classes]``; anchors from all pyramid levels
            are concatenated, P3 first.
        """
        batch_size = x.size(0)
        loc_preds = []
        cls_preds = []

        for fmap in self.fpn(x):  # p3, p4, p5, p6, p7
            loc_pred = self.regression_head(fmap)
            cls_pred = self.classification_head(fmap)

            # Channels last, then flatten (H, W, anchors) into one axis.
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, self.num_classes)

            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)

        return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)

    def freeze_bn(self):
        '''Freeze BatchNorm layers.'''
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
class FocalLoss(nn.Module):
    """Combined box-regression and focal classification loss.

    Args:
        n_classes: Number of object classes (background excluded).
    """

    def __init__(self, n_classes=8):
        super().__init__()
        self.n_classes = n_classes

    def focal_loss(self, x, y):
        """Summed focal-style classification loss.

        Args:
            x: ``[N, n_classes]`` class logits.
            y: ``[N]`` targets in ``0..n_classes``, where 0 is background.

        Returns:
            Scalar tensor: the loss summed over all anchors and classes.
        """
        alpha = 0.25  # weight of positive entries; negatives get 1 - alpha

        # One-hot over (background + classes), then drop the background
        # column so background rows are all zeros. `.to(x)` matches the
        # dtype and device of the logits.
        t = F.one_hot(y.long(), 1 + self.n_classes)[:, 1:].to(x)

        xt = x * (2 * t - 1)  # xt = x where t == 1, else -x
        # sigmoid(2 * xt + 1) with the final "/ 2" is this loss's fixed
        # focusing form; no separate gamma exponent is applied.
        pt = (2 * xt + 1).sigmoid()

        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()

    def forward(self, loc_preds, loc_true, cls_preds, cls_true):
        """Compute ``(loc_loss + cls_loss) / max(#positive anchors, 1)``.

        Args:
            loc_preds: ``[B, #anchors, 4]`` predicted offsets.
            loc_true: ``[B, #anchors, 4]`` target offsets.
            cls_preds: ``[B, #anchors, n_classes]`` class logits.
            cls_true: ``[B, #anchors]`` targets; ``> 0`` positive,
                ``0`` background, ``-1`` ignored.

        Returns:
            Scalar loss tensor.
        """
        pos = cls_true > 0
        num_pos = pos.long().sum()

        ## Loc loss: positive anchors only.
        mask = pos.unsqueeze(2).expand_as(loc_preds)       # [N,#anchors,4]
        masked_loc_preds = loc_preds[mask].view(-1, 4)     # [#pos,4]
        masked_loc_true = loc_true[mask].view(-1, 4)       # [#pos,4]
        loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_true, reduction='sum')

        ## cls loss: positives and background, ignored anchors excluded.
        pos_neg = cls_true > -1
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
        masked_cls_preds = cls_preds[mask].view(-1, self.n_classes)
        cls_loss = self.focal_loss(masked_cls_preds, cls_true[pos_neg])

        # Clamp so a batch without any positive anchor does not divide by 0.
        return (loc_loss + cls_loss) / num_pos.clamp(min=1)
