{
"cells": [
{
"cell_type": "markdown",
"id": "a2c3541f-8c5f-490d-8ada-2d9264a72074",
"metadata": {},
"source": [
"# RetinaNet Implementation in PyTorch\n",
"\n",
"Implementation of the following paper: [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf)"
]
},
{
"cell_type": "markdown",
"id": "cc7a980f-fddb-458f-b48e-fd0f12de3e58",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f124e21-1f07-4bfd-89e0-fdff09aa4e0d",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import copy"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6e523509-e024-4df9-82eb-c64df1d60110",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c3bfdb2b-b2a2-4570-93c2-9a9a89a5615e",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"from torchvision import transforms, datasets\n",
"from torchvision.transforms import functional as FT\n",
"from torchvision.transforms import transforms as T"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "feb9f59e-1e07-4810-9408-58b23cf27c18",
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import os\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f13fa1a6-d707-4f58-9202-9f65fa67a41f",
"metadata": {},
"outputs": [],
"source": [
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8318d1c0-75da-44ab-9f01-521d7b0fd738",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.9.0', '0.10.0')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__, torchvision.__version__"
]
},
{
"cell_type": "markdown",
"id": "bd6b3cf8-65e3-436d-bf20-c553dd592f32",
"metadata": {},
"source": [
"## Transforms and Utilities"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "df785a0e-d026-4e2d-b2d6-f95b81db6757",
"metadata": {},
"outputs": [],
"source": [
"class Compose:\n",
" def __init__(self, transforms):\n",
" self.transforms = transforms\n",
"\n",
" def __call__(self, image, target):\n",
" for t in self.transforms:\n",
" image, target = t(image, target)\n",
" return image, target"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "871e66f6-bc56-4095-90f0-2a5f603fd29e",
"metadata": {},
"outputs": [],
"source": [
"class Normalizer(object):\n",
"\n",
" def __init__(self):\n",
" self.mean = [0.485, 0.456, 0.406]\n",
" self.std = [0.229, 0.224, 0.225]\n",
" self.normalize = T.Compose([T.Normalize(mean=self.mean, std=self.std)])\n",
"\n",
" def __call__(self, image, target):\n",
"\n",
" return self.normalize(image), target"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1d953822-6c79-4511-bd1c-19a84633f703",
"metadata": {},
"outputs": [],
"source": [
"class Resize(object):\n",
" def __init__(self, size=400):\n",
" self.size = size\n",
" def __call__(self, img, target):\n",
" size = self.size\n",
" boxes = [t['bbox'] for t in target]\n",
" w, h = img.size\n",
" if isinstance(size, int):\n",
" size_min = min(w,h)\n",
" size_max = max(w,h)\n",
" sw = sh = float(size) / size_min\n",
" if sw * size_max > 800:\n",
" sw = sh = float(800) / size_max\n",
" ow = int(w * sw + 0.5)\n",
" oh = int(h * sh + 0.5)\n",
" else:\n",
" ow, oh = size\n",
" sw = float(ow) / w\n",
" sh = float(oh) / h\n",
" boxes = (torch.FloatTensor(boxes)*torch.Tensor([sw,sh,sw,sh])).tolist()\n",
" for t in range(len(target)):\n",
" target[t]['bbox'] = boxes[t]\n",
" return img.resize((ow,oh), Image.BILINEAR), target"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b3a9d7e4-9fee-4ed0-a205-892721488fb8",
"metadata": {},
"outputs": [],
"source": [
"class ToTensor(nn.Module):\n",
" def forward(\n",
" self, image, target = None\n",
" ):\n",
" image = FT.pil_to_tensor(image)\n",
" image = FT.convert_image_dtype(image)\n",
" return image, target"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "246f31f3-e29a-4b23-9f14-7c15468f94e8",
"metadata": {},
"outputs": [],
"source": [
"class PILToTensor(nn.Module):\n",
" def forward(\n",
" self, image, target = None\n",
" ):\n",
" image = FT.pil_to_tensor(image)\n",
" return image, target"
]
},
{
"cell_type": "markdown",
"id": "0a2ecfe0-6bda-4ff0-8de9-d8f5cfacb5ba",
"metadata": {},
"source": [
"## Dataset"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bb23610c-6a83-4236-9cae-0cfbad97d9de",
"metadata": {},
"outputs": [],
"source": [
"#### COLAB LOADER ####\n",
"# !curl -L \"https://public.roboflow.com/ds/L6PD1uTSPF?key=Gq3tCeIqHA\" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip\n",
"# Use for colab only #\n",
"######################"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "57619344-33c5-40fc-8d34-ef7529487d00",
"metadata": {},
"outputs": [],
"source": [
"from pycocotools.coco import COCO"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c7bab354-9b31-497d-834c-1ebf5d0ecf9f",
"metadata": {},
"outputs": [],
"source": [
"dataset_path = \"/Volumes/Samsung_T5/Documents/MachineLearning/machine_learning_notebooks/pytorch/aquarium-dataset/Aquarium Combined/\""
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "96370ec5-bc2d-475c-9b94-8a228cf3c982",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading annotations into memory...\n",
"Done (t=0.02s)\n",
"creating index...\n",
"index created!\n"
]
},
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"coc = COCO(os.path.join(dataset_path, \"train\", \"_annotations.coco.json\"))\n",
"categories = coc.cats\n",
"n_classes = len(categories.keys())\n",
"n_classes"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "12a71e96-e1cf-41ea-b8d0-5ce1b14d3ff7",
"metadata": {},
"outputs": [],
"source": [
"def xyxy_2_xywh(boxes):\n",
" a = torch.FloatTensor(boxes[:,:2])\n",
" b = torch.FloatTensor(boxes[:,2:])\n",
" return torch.cat([(a+b)/2,b-a+1], 1)\n",
" \n",
"def xywh_2_xyxy(boxes):\n",
" a = torch.FloatTensor(boxes[:,:2])\n",
" b = torch.FloatTensor(boxes[:,2:])\n",
" return torch.cat([a-b/2,a+b/2], 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "826dfb96-824a-44e5-9268-249c0756e7ae",
"metadata": {},
"outputs": [],
"source": [
"def box_nms(bboxes, scores, threshold=0.5, mode='union'):\n",
" \n",
" x1 = bboxes[:,0]\n",
" y1 = bboxes[:,1]\n",
" x2 = bboxes[:,2]\n",
" y2 = bboxes[:,3]\n",
"\n",
" areas = (x2-x1+1) * (y2-y1+1)\n",
" _, order = scores.sort(0, descending=True)\n",
"\n",
" keep = []\n",
" while order.numel() > 0:\n",
" if order.numel() == 1:\n",
" keep.append(order.item())\n",
" break\n",
" \n",
" i = order[0]\n",
" keep.append(i)\n",
"\n",
" xx1 = x1[order[1:]].clamp(min=x1[i])\n",
" yy1 = y1[order[1:]].clamp(min=y1[i])\n",
" xx2 = x2[order[1:]].clamp(max=x2[i])\n",
" yy2 = y2[order[1:]].clamp(max=y2[i])\n",
"\n",
" w = (xx2-xx1+1).clamp(min=0)\n",
" h = (yy2-yy1+1).clamp(min=0)\n",
" inter = w*h\n",
"\n",
" if mode == 'union':\n",
" ovr = inter / (areas[i] + areas[order[1:]] - inter)\n",
" elif mode == 'min':\n",
" ovr = inter / areas[order[1:]].clamp(max=areas[i])\n",
" else:\n",
" raise TypeError('Unknown nms mode: %s.' % mode)\n",
"\n",
" ids = (ovr<=threshold).nonzero().squeeze()\n",
" if ids.numel() == 0:\n",
" break\n",
" order = order[ids+1]\n",
" return torch.LongTensor(keep)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "88ac1227-a417-403a-8cac-6b04faadc851",
"metadata": {},
"outputs": [],
"source": [
"def iou(box1, box2, order=\"xyxy\"):\n",
" if order == \"xywh\":\n",
" box1 = xywh_2_xyxy(box1)\n",
" box2 = xywh_2_xyxy(box2)\n",
" N = box1.size(0)\n",
" M = box2.size(0)\n",
"\n",
" lt = torch.max(box1[:,None,:2], box2[:,:2]) # [N,M,2]\n",
" rb = torch.min(box1[:,None,2:], box2[:,2:]) # [N,M,2]\n",
"\n",
" wh = (rb-lt+1).clamp(min=0) # [N,M,2]\n",
" inter = wh[:,:,0] * wh[:,:,1] # [N,M]\n",
"\n",
" area1 = (box1[:,2]-box1[:,0]+1) * (box1[:,3]-box1[:,1]+1) # [N,]\n",
" area2 = (box2[:,2]-box2[:,0]+1) * (box2[:,3]-box2[:,1]+1) # [M,]\n",
" iou = inter / (area1[:,None] + area2 - inter)\n",
" return iou"
]
},
{
"cell_type": "markdown",
"id": "f313d08d-9c3b-400b-9b8b-28cbd5ccdec9",
"metadata": {},
"source": [
"### Anchor Boxes\n",
"\n",
"\"*Anchor boxes have areas of $32^2$ to $512^2$ on pyramid levels $P_3$ to $P_7$.*\" (Page 4, Focal Loss for Dense Object Detection)\n",
"\n",
"- Aspect ratios: $\\{1:2, 1:1, 2:1\\}$, translates to `[0.5, 1, 2]` in python\n",
"- Scales: $\\{2^0, 2^{1/3}, 2^{2/3}\\}$\n",
"\n",
"There should be a total of $A=9$ anchors per level"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "359ef414-00f1-487d-9269-6d2316b10e80",
"metadata": {},
"outputs": [],
"source": [
"class AnchorBox():\n",
" \"\"\"\n",
" Generate anchor boxes for level 3 to level 8\n",
" \"\"\"\n",
" def __init__(self):\n",
" self.ratios = [0.5, 1, 2]\n",
" self.scales = [1, 2**(1/3), 2**(2/3)]\n",
" \n",
" self.A = len(self.ratios) * len(self.scales) # number of anchors (from paper)\n",
" self.areas = [x**2 for x in [32, 64, 128, 256, 512]] # P3, P4, P5, P6, P7\n",
" self.strides = [2 ** i for i in range(3, 8)] # Each layer's feature map is 2^l smaller than the input\n",
" self.anchor_dims = self._anchor_dims()\n",
" ## for feature map sizes\n",
" \n",
" def _meshgrid(self, x, y, row_major=True):\n",
" a = torch.arange(0,x)\n",
" b = torch.arange(0,y)\n",
" xx = a.repeat(y).view(-1,1)\n",
" yy = b.view(-1,1).repeat(1,x).view(-1,1)\n",
" return torch.cat([xx,yy],1) if row_major else torch.cat([yy,xx],1)\n",
" \n",
" def _anchor_dims(self):\n",
" anchor_dims = []\n",
" for area in self.areas:\n",
" for ratio in self.ratios:\n",
" anchor_height = math.sqrt(area / ratio)\n",
" anchor_width = area / anchor_height\n",
" \n",
" for scale in self.scales:\n",
" anchor_width = anchor_width * scale\n",
" anchor_height = anchor_height * scale\n",
" anchor_dims.append([anchor_width, anchor_height])\n",
" return torch.FloatTensor(anchor_dims).view(len(self.areas), -1, 2)\n",
" \n",
" def generate_anchor_boxes(self, input_size):\n",
" \"\"\"\n",
" Generates Anchor Boxes\n",
" \n",
" input_size: torch.Tensor: (w, h)\n",
" \"\"\"\n",
" \n",
" num_feature_maps = len(self.areas)\n",
" feature_map_sizes = [(input_size / stride).ceil() for stride in self.strides] # calculating feature map sizes of p3 to p7\n",
" boxes = []\n",
" for i in range(num_feature_maps):\n",
" fm_size = feature_map_sizes[i]\n",
" grid_size = input_size / fm_size\n",
" fm_w, fm_h = int(fm_size[0]), int(fm_size[1])\n",
" xy = self._meshgrid(fm_w,fm_h) + 0.5 # [fm_h*fm_w, 2]\n",
" xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)\n",
" wh = self.anchor_dims[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)\n",
" box = torch.cat([xy,wh], 3) # [x,y,w,h]\n",
" boxes.append(box.view(-1,4))\n",
" return torch.cat(boxes, 0)"
]
},
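{
"cell_type": "markdown",
"id": "anchor-box-check-md",
"metadata": {},
"source": [
"A quick, purely illustrative sanity check of the generator above (assuming a made-up 300x300 input): `anchor_dims` should hold $A = 9$ width/height pairs for each of the five pyramid levels, and `generate_anchor_boxes` should return one `[x, y, w, h]` row per anchor over $P_3$ to $P_7$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "anchor-box-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the 300x300 input size is an arbitrary assumption.\n",
"anchor_box = AnchorBox()\n",
"example_anchors = anchor_box.generate_anchor_boxes(torch.Tensor([300, 300]))\n",
"anchor_box.A, anchor_box.anchor_dims.shape, example_anchors.shape  # 9, [5 levels, 9, 2], [total anchors, 4]"
]
},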
{
"cell_type": "code",
"execution_count": 22,
"id": "2d71e06b-3ce1-41f1-afae-034a62bc98b6",
"metadata": {},
"outputs": [],
"source": [
"class Encoder:\n",
" def __init__(self):\n",
" self.anchor_box = AnchorBox()\n",
" def encode(self, boxes, labels, input_size):\n",
" input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \\\n",
" else torch.Tensor(input_size)\n",
" anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n",
"# boxes = xyxy_2_xywh(boxes)\n",
" boxes = torch.FloatTensor(boxes)\n",
" \n",
" ious = iou(anchor_boxes, boxes, order=\"xywh\")\n",
" max_ious, max_ids = ious.max(1)\n",
" boxes = boxes[max_ids]\n",
" \n",
" loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]\n",
" loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])\n",
" loc_targets = torch.cat([loc_xy,loc_wh], 1)\n",
" cls_targets = 1 + labels[max_ids]\n",
"\n",
" cls_targets[max_ious<0.5] = 0\n",
" ignore = (max_ious>0.4) & (max_ious<0.5) # ignore ious between [0.4,0.5]\n",
" cls_targets[ignore] = -1 # for now just mark ignored to -1\n",
" return loc_targets, cls_targets\n",
" \n",
" def decode(self, loc_preds, cls_preds, input_size):\n",
" input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) else torch.Tensor(input_size)\n",
" \n",
" CLS_THRESH = 0.5\n",
" NMS_THRESH = 0.5\n",
" anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n",
"\n",
" loc_xy = loc_preds[:,:2]\n",
" loc_wh = loc_preds[:,2:]\n",
"\n",
" xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]\n",
" wh = loc_wh.exp() * anchor_boxes[:,2:]\n",
" boxes = torch.cat([xy-wh/2, xy+wh/2], 1)\n",
"\n",
" score, labels = cls_preds.sigmoid().max(1)\n",
" ids = score > CLS_THRESH\n",
" ids = ids.nonzero().squeeze()\n",
" keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)\n",
" return boxes[ids][keep], labels[ids][keep]"
]
},
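{
"cell_type": "markdown",
"id": "encoder-check-md",
"metadata": {},
"source": [
"`Encoder.encode` matches every anchor to its highest-IoU ground-truth box: anchors whose best IoU is at least 0.5 get the matched class, those between 0.4 and 0.5 are marked ignored (-1), and the rest are background (0). Below is a minimal illustrative call with a single made-up box (assumed `[cx, cy, w, h]` order) and label; it only checks that one regression row and one class target come back per anchor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "encoder-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the box, label, and 300x300 input size are made-up values.\n",
"example_encoder = Encoder()\n",
"example_boxes = np.array([[150.0, 150.0, 60.0, 40.0]])  # one box, [cx, cy, w, h]\n",
"example_labels = np.array([2])\n",
"example_loc, example_cls = example_encoder.encode(example_boxes, example_labels, (300, 300))\n",
"example_loc.shape, example_cls.shape  # one loc row and one class target per anchor"
]
},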
{
"cell_type": "code",
"execution_count": 53,
"id": "91eb2fc1-4fe1-4558-8742-619b89286361",
"metadata": {},
"outputs": [],
"source": [
"class AquariumDetection(datasets.VisionDataset):\n",
" def __init__(\n",
" self,\n",
" root: str,\n",
" split = \"train\",\n",
" transform= None,\n",
" target_transform = None,\n",
" transforms = None,\n",
" ) -> None:\n",
" super().__init__(root, transforms, transform, target_transform)\n",
" self.split = split\n",
" self.coco = COCO(os.path.join(root, split, \"_annotations.coco.json\"))\n",
" self.ids = list(sorted(self.coco.imgs.keys()))\n",
" self.ids = [id for id in self.ids if (len(self._load_target(id)) > 0)]\n",
"\n",
" def _load_image(self, id: int) -> Image.Image:\n",
" path = self.coco.loadImgs(id)[0][\"file_name\"]\n",
" img = Image.open(os.path.join(self.root, self.split, path)).convert(\"RGB\")\n",
" return img\n",
"\n",
" def _load_target(self, id: int):\n",
" return self.coco.loadAnns(self.coco.getAnnIds(id))\n",
"\n",
" def __getitem__(self, index: int):\n",
" id = self.ids[index]\n",
" image = self._load_image(id)\n",
" target = copy.deepcopy(self._load_target(id))\n",
"\n",
" if self.transforms is not None:\n",
" image, target = self.transforms(image, target)\n",
" \n",
" annot = [t[\"bbox\"] + [t[\"category_id\"]] for t in target]\n",
"\n",
" return image, annot\n",
"\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self.ids)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "3701d805-c03a-410c-a51a-0f7305d1cf35",
"metadata": {},
"outputs": [],
"source": [
"def collate_fn(batch):\n",
" \"\"\"\n",
" The images in the dataset will be of different sizes. This function takes the images and pads them. Then we encode the images.\n",
" \"\"\"\n",
" imgs = [x[0] for x in batch]\n",
" annots = np.array([x[1] for x in batch], dtype=object)\n",
"\n",
" widths = [int(s.shape[1]) for s in imgs]\n",
" heights = [int(s.shape[2]) for s in imgs]\n",
" batch_size = len(imgs)\n",
"\n",
" max_width = np.array(widths).max()\n",
" max_height = np.array(heights).max()\n",
"\n",
" padded_imgs = torch.zeros(batch_size, max_width, max_height, 3)\n",
"\n",
" for i in range(batch_size):\n",
" img = imgs[i]\n",
" padded_imgs[i, :int(img.shape[1]), :int(img.shape[2]), :] = img.permute(1, 2, 0)\n",
" padded_imgs = padded_imgs.permute(0, 3, 1, 2)\n",
" \n",
" ## Encode ##\n",
" encoder = Encoder()\n",
" loc_targets = []\n",
" cls_targets = []\n",
" for i in range(len(imgs)):\n",
" annot = annots[i]\n",
" boxes = np.array(annot)[:, 0:4]\n",
" labels = np.array(annot)[:, 4]\n",
" image = padded_imgs[i]\n",
" loc_target, cls_target = encoder.encode(boxes, labels, (image.shape[1], image.shape[2]))\n",
" loc_targets.append(torch.FloatTensor(loc_target))\n",
" cls_targets.append(torch.FloatTensor(cls_target))\n",
" return {'img': padded_imgs, 'loc_targets': torch.stack(loc_targets), 'cls_targets': torch.stack(cls_targets)}"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "351a5ec0-56e9-4c56-a1b0-7adb6543e511",
"metadata": {},
"outputs": [],
"source": [
"def get_transform(train):\n",
" transforms = []\n",
" transforms.append(Resize(size=300))\n",
" transforms.append(ToTensor())\n",
" transforms.append(Normalizer())\n",
" return Compose(transforms)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "c693c3ab-0176-4c06-b2a7-309d24d8bb31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading annotations into memory...\n",
"Done (t=0.02s)\n",
"creating index...\n",
"index created!\n",
"loading annotations into memory...\n",
"Done (t=0.00s)\n",
"creating index...\n",
"index created!\n",
"loading annotations into memory...\n",
"Done (t=0.00s)\n",
"creating index...\n",
"index created!\n"
]
}
],
"source": [
"train_dataset = AquariumDetection(root=dataset_path, transforms=get_transform(True))\n",
"val_dataset = AquariumDetection(root=dataset_path, split=\"valid\", transforms=get_transform(False))\n",
"test_dataset = AquariumDetection(root=dataset_path, split=\"test\", transforms=get_transform(False))"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "6a890ca3-b3a6-4f6a-a4f3-6fdf3a518cb4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"train_loader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)\n",
"test_loader = DataLoader(test_dataset, batch_size=8, collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "a5e86435-ff29-4cbe-bc47-078dc15c8de7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(56, 16, 8)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_loader), len(val_loader), len(test_loader)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "d9332820-f53f-4a3a-b6cc-aa8c0df89596",
"metadata": {},
"outputs": [],
"source": [
"for i in range(len(train_dataset)):\n",
" _ = train_dataset[i]"
]
},
{
"cell_type": "markdown",
"id": "d4cf6818-a90d-435c-b88c-8c6b07f03248",
"metadata": {
"tags": []
},
"source": [
"## Retinanet Implementation"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "ad44cca1-4e59-4058-9ddf-e9743650fb58",
"metadata": {},
"outputs": [],
"source": [
"def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n",
" \"\"\"3x3 convolution with padding\"\"\"\n",
" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
" padding=dilation, groups=groups, bias=False, dilation=dilation)\n",
"\n",
"\n",
"def conv1x1(in_planes, out_planes, stride=1):\n",
" \"\"\"1x1 convolution\"\"\"\n",
" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)"
]
},
{
"cell_type": "code",
"execution_count": 168,
"id": "cb82dba8-cc96-49f0-aba1-fbf83775b8e8",
"metadata": {},
"outputs": [],
"source": [
"class Bottleneck(nn.Module):\n",
" expansion = 4\n",
" def __init__(self, inplanes, planes, stride=1, groups=1,\n",
" base_width=64, dilation=1):\n",
" super(Bottleneck, self).__init__()\n",
" norm_layer = nn.BatchNorm2d\n",
" width = int(planes * (base_width / 64.)) * groups\n",
" self.conv1 = conv1x1(inplanes, width)\n",
" self.bn1 = norm_layer(width)\n",
" self.conv2 = conv3x3(width, width, stride, groups, dilation)\n",
" self.bn2 = norm_layer(width)\n",
" self.conv3 = conv1x1(width, planes * self.expansion)\n",
" self.bn3 = norm_layer(planes * self.expansion)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" \n",
" self.downsample = nn.Sequential()\n",
" if stride != 1 or inplanes != self.expansion*planes:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(inplanes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(self.expansion*planes)\n",
" )\n",
" \n",
" self.stride = stride\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv3(out)\n",
" out = self.bn3(out)\n",
"\n",
" identity = self.downsample(x)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
"\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 169,
"id": "5ca298d4-9375-41c1-b66f-8d60640b5581",
"metadata": {},
"outputs": [],
"source": [
"class FPN(nn.Module):\n",
" def __init__(self, block, num_blocks):\n",
" super(FPN, self).__init__()\n",
" self.in_planes = 64\n",
" \n",
" self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(64)\n",
" \n",
" self.conv2 = self._make_layer(block, 64, num_blocks=num_blocks[0], stride=1)\n",
" self.conv3 = self._make_layer(block, 128, num_blocks=num_blocks[1], stride=2)\n",
" self.conv4 = self._make_layer(block, 256, num_blocks=num_blocks[2], stride=2)\n",
" self.conv5 = self._make_layer(block, 512, num_blocks=num_blocks[3], stride=2)\n",
" \n",
" self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)\n",
" self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)\n",
" \n",
" ## lateral layers ##\n",
" self.lat1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)\n",
" self.lat2 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)\n",
" self.lat3 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)\n",
" \n",
" ## top-down layers ##\n",
" self.topdown1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n",
" self.topdown2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n",
" \n",
" self.relu = nn.ReLU()\n",
" \n",
" def _upsample_and_add(self, x, y):\n",
" _,_,H,W = y.size()\n",
" return F.upsample(x, size=(H,W), mode='bilinear') + y\n",
" \n",
" def _make_layer(self, block, planes, num_blocks, stride):\n",
" strides = [stride] + [1]*(num_blocks-1)\n",
" layers = []\n",
" for stride in strides:\n",
" layers.append(block(self.in_planes, planes, stride))\n",
" self.in_planes = planes * block.expansion\n",
" return nn.Sequential(*layers)\n",
" \n",
" def forward(self, x):\n",
" #bottom up\n",
" c1 = self.relu(self.bn1(self.conv1(x)))\n",
" c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)\n",
" c2 = self.conv2(c1)\n",
" c3 = self.conv3(c2)\n",
" c4 = self.conv4(c3)\n",
" c5 = self.conv5(c4)\n",
" p6 = self.conv6(c5)\n",
" p7 = self.conv7(p6)\n",
" p5 = self.lat1(c5)\n",
" p4 = self._upsample_and_add(p5, self.lat2(c4))\n",
" p4 = self.topdown1(p4)\n",
" p3 = self._upsample_and_add(p4, self.lat3(c3))\n",
" p3 = self.topdown2(p3)\n",
" return p3, p4, p5, p6, p7"
]
},
{
"cell_type": "code",
"execution_count": 170,
"id": "5d228fae-537c-4ac6-a3b5-1e838112eb3f",
"metadata": {},
"outputs": [],
"source": [
"class ClassificationHead(nn.Module):\n",
" def __init__(self, n_classes=8):\n",
" super(ClassificationHead, self).__init__()\n",
" self.n_anchors = 9\n",
" self.n_classes = n_classes\n",
" \n",
" self.convnet = nn.Sequential(*[\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, self.n_classes*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n",
" ])\n",
" def forward(self, x):\n",
" return self.convnet(x)"
]
},
{
"cell_type": "code",
"execution_count": 171,
"id": "4c57cc78-6615-45a3-93ce-242157536e0f",
"metadata": {},
"outputs": [],
"source": [
"class RegressionHead(nn.Module):\n",
" def __init__(self):\n",
" super(RegressionHead, self).__init__()\n",
" self.n_anchors = 9\n",
" self.convnet = nn.Sequential(*[\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
" nn.ReLU(True),\n",
" nn.Conv2d(256, 4*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n",
" ])\n",
" def forward(self, x):\n",
" return self.convnet(x)"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "86f363cc-7c90-4be0-8522-89cc96d087a2",
"metadata": {},
"outputs": [],
"source": [
"class RetinaNet(nn.Module):\n",
" def __init__(self, n_classes=8):\n",
" super(RetinaNet, self).__init__()\n",
" \n",
" self.fpn = FPN(Bottleneck, [3, 4, 6, 3])\n",
" \n",
" self.num_classes = n_classes\n",
" \n",
" self.classification_head = ClassificationHead(n_classes = self.num_classes) # class head\n",
" self.regression_head = RegressionHead() # loc head\n",
" def forward(self, x):\n",
" feature_maps = self.fpn(x) #p3, p4, p5, p6, p7\n",
" \n",
" loc_preds = []\n",
" cls_preds = []\n",
" \n",
" for fmap in feature_maps:\n",
" loc_pred = self.regression_head(fmap)\n",
" cls_pred = self.classification_head(fmap)\n",
" \n",
" loc_pred = loc_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,4) \n",
" cls_pred = cls_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,self.num_classes) \n",
" \n",
" loc_preds.append(loc_pred)\n",
" cls_preds.append(cls_pred)\n",
" \n",
" return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)\n",
" \n",
" def freeze_bn(self):\n",
" '''Freeze BatchNorm layers.'''\n",
" for layer in self.modules():\n",
" if isinstance(layer, nn.BatchNorm2d):\n",
" layer.eval()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 173,
"id": "767203b7-fbdb-4708-a942-42022b019bc8",
"metadata": {},
"outputs": [],
"source": [
"net = RetinaNet()"
]
},
{
"cell_type": "code",
"execution_count": 176,
"id": "233b4698-e07b-4430-9c88-e3b1f1106e75",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([8, 30231, 4]), torch.Size([8, 30231, 8]))"
]
},
"execution_count": 176,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch = next(iter(train_loader))\n",
"loc_preds, cls_preds = net(batch['img'])\n",
"loc_preds.shape, cls_preds.shape # The 2nd number should be the same as the number of anchors per image."
]
},
{
"cell_type": "code",
"execution_count": 179,
"id": "6abb61c1-03e4-4530-b56b-e401965c5917",
"metadata": {},
"outputs": [],
"source": [
"encoder = Encoder()\n",
"# _ = encoder.decode(loc_preds[0], cls_preds[0], tuple(batch['img'].shape[2:]))\n",
"## Ensure this cell just runs ##"
]
},
{
"cell_type": "markdown",
"id": "414de16f-fee8-4d7e-87c3-2f9484911786",
"metadata": {},
"source": [
"## Focal Loss\n",
"\n",
"An extension of Cross Entropy\n",
"$$\n",
"FL(p_t) = -\\alpha(1-p_t)^{\\gamma}log(p_t)\n",
"$$"
]
},
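{
"cell_type": "markdown",
"id": "focal-loss-demo-md",
"metadata": {},
"source": [
"To see what the modulating factor $(1-p_t)^{\\gamma}$ does, the toy comparison below (purely illustrative, using the paper's $\\alpha = 0.25$ and $\\gamma = 2$) computes plain cross-entropy and focal loss for a few values of $p_t$: the easy example ($p_t = 0.9$) is down-weighted far more heavily than the hard one ($p_t = 0.1$)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "focal-loss-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: compare cross-entropy and focal loss at a few p_t values.\n",
"example_pt = torch.tensor([0.9, 0.6, 0.1])  # easy -> hard examples\n",
"example_ce = -torch.log(example_pt)\n",
"example_fl = -0.25 * (1 - example_pt) ** 2 * torch.log(example_pt)\n",
"torch.stack([example_pt, example_ce, example_fl], dim=1)  # columns: p_t, CE, FL"
]
},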
{
"cell_type": "code",
"execution_count": 180,
"id": "ab44b1d4-fd07-4334-b84e-adb1bd4b9ee0",
"metadata": {},
"outputs": [],
"source": [
"def one_hot_embedding(labels, num_classes):\n",
" y = torch.eye(num_classes)\n",
" return y[labels]"
]
},
{
"cell_type": "code",
"execution_count": 222,
"id": "fc16aa9b-f88e-461b-bdb6-3a39202a4330",
"metadata": {},
"outputs": [],
"source": [
"class FocalLoss(nn.Module):\n",
" def __init__(self, n_classes = 8):\n",
" super().__init__()\n",
" self.n_classes = n_classes\n",
" \n",
" def focal_loss(self, x, y):\n",
" alpha = -0.25\n",
" gamma = 2 # Paper recommended values\n",
" \n",
" t = one_hot_embedding(y.cpu(), 1 + self.n_classes)\n",
" t = t[:,1:]\n",
" if torch.cuda.is_available():\n",
" t = t.cuda()\n",
" \n",
" xt = x*(2*t-1)\n",
" pt = (2*xt+1).sigmoid()\n",
" \n",
" w = alpha*t + (1-alpha)*(1-t)\n",
" loss = -w*pt.log() / 2\n",
" return loss.sum()\n",
" \n",
" def forward(self, loc_preds, loc_true, cls_preds, cls_true):\n",
" batch_size, num_boxes = cls_true.size()\n",
" pos = cls_true > 0\n",
" num_pos = pos.long().sum()\n",
" \n",
" ## Loc loss\n",
" mask = pos.unsqueeze(2).expand_as(loc_preds) # [N,#anchors,4]\n",
" masked_loc_preds = loc_preds[mask].view(-1,4) # [#pos,4]\n",
" masked_loc_true = loc_true[mask].view(-1,4) # [#pos,4]\n",
" loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_true, size_average=False)\n",
" ## cls loss\n",
" pos_neg = cls_true > -1 # exclude ignored anchors\n",
" mask = pos_neg.unsqueeze(2).expand_as(cls_preds)\n",
" masked_cls_preds = cls_preds[mask].view(-1,self.n_classes)\n",
" cls_loss = self.focal_loss(masked_cls_preds, cls_true[pos_neg])\n",
" \n",
" loss = (loc_loss + cls_loss) / num_pos\n",
" return loss"
]
},
{
"cell_type": "markdown",
"id": "90160ab7-8e15-4306-9280-798d43165794",
"metadata": {},
"source": [
"## Initialization"
]
},
{
"cell_type": "code",
"execution_count": 223,
"id": "2e63b8db-31f2-48c0-9f18-8594e3f82d6e",
"metadata": {},
"outputs": [],
"source": [
"criterion = FocalLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 224,
"id": "d6b6b57c-68ac-40d9-91f3-a15e7e9f1af7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"execution_count": 224,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 225,
"id": "c8b82edb-389f-471d-a05a-c2ac3890acb8",
"metadata": {},
"outputs": [],
"source": [
"net = net.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "markdown",
"id": "b801f336-8c55-4217-8e67-221970c55f00",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 234,
"id": "9497a504-1e21-4c3f-8339-7ad1afc086f8",
"metadata": {},
"outputs": [],
"source": [
"def train(epoch):\n",
" net.train()\n",
" net.freeze_bn()\n",
" train_loss = 0\n",
" for batch in tqdm(train_loader):\n",
" imgs = batch['img'].to(device)\n",
" loc_targets = batch['loc_targets'].to(device)\n",
" cls_targets = batch['cls_targets'].to(device)\n",
" cls_targets = cls_targets.long()\n",
" \n",
" optimizer.zero_grad()\n",
" loc_pred, cls_pred = net(imgs)\n",
" loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n",
" \n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" train_loss += loss.item()\n",
" print('train_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 235,
"id": "c5bc39c8-431e-4fe1-ae38-03d484e7f952",
"metadata": {},
"outputs": [],
"source": [
"def test(epoch, loader):\n",
" with torch.no_grad():\n",
" net.eval()\n",
" test_loss = 0\n",
" for batch in tqdm(loader):\n",
" imgs = batch['img'].to(device)\n",
" loc_targets = batch['loc_targets'].to(device)\n",
" cls_targets = batch['cls_targets'].to(device)\n",
"\n",
" loc_pred, cls_pred = net(imgs)\n",
" loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n",
" test_loss += loss[0]\n",
" print('test_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 236,
"id": "503f99ad-fed4-48af-ae08-7a430b9c8ae5",
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 50"
]
},
{
"cell_type": "code",
"execution_count": 237,
"id": "c7534024-d9a2-430a-8ddc-a6db0ad0bd57",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH 1\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ec194265c36b440383e074ffb997acce",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/56 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n",
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n",
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n",
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n",
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n",
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/107800544.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"EPOCH {epoch}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/3530921501.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc_targets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m inputs=inputs)\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for epoch in range(1, EPOCHS + 1):\n",
" print(f\"EPOCH {epoch}\")\n",
" ##\n",
" train(epoch)\n",
" test(epoch, val_loader)\n",
" ##"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79d32e91-d705-4711-8a71-79d94cd5f620",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7206910-c8cc-4909-84c3-76a7f3ee023e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "e33c98cd-3787-45d6-86eb-cf7a65aaf882",
"metadata": {},
"source": [
"## Testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2d3a81a-43d3-4da1-8f7c-db98adeff2d5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "8efa5f7e-a5c9-4942-bde2-7e0897f2e036",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "69670f59-7b7a-46a2-abf7-80180d8f994e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "afd407d5-6d0a-46bc-b769-7b7e604accfb",
"metadata": {},
"source": [
"## Saving Model Weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a8ef18e-aa6b-4cdb-8ef5-c91671b76e9e",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"retinanet.pt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90ff717d-7d5c-413d-9ae8-c57b935f26f9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
},
"toc-autonumbering": true
},
"nbformat": 4,
"nbformat_minor": 5
}