Skip to content

Instantly share code, notes, and snippets.

@rsomani95
Last active July 20, 2022 06:08
Show Gist options
  • Save rsomani95/7922ea1e25fe68358be230d841d272a9 to your computer and use it in GitHub Desktop.
Save rsomani95/7922ea1e25fe68358be230d841d272a9 to your computer and use it in GitHub Desktop.
Example of how to load a `timm` architecture into the YOLOX experiment setup. This file uses `ghostnet_100` specifically, but the same approach extends to any other `timm` architecture that supports the `features_only` interface.
import timm
import torch
import torch.distributed as dist
import torch.nn as nn
from upyog.imports import *
from yolox.exp.yolox_base import Exp as DefaultBaseExp
from yolox.models import YOLOPAFPN, YOLOX, YOLOXHead
from yolox.utils import get_local_rank, wait_for_the_master
# Public API of this module: only the experiment class is exported by
# `from ... import *`; the freeze_layer helper is not listed here.
__all__ = ["GhostNetAABaseCOCOExp"]
def freeze_layer(m: nn.Module) -> None:
    """Freeze ``m`` in place by turning off gradients for all of its parameters.

    Only parameters are affected; buffers (e.g. BatchNorm running statistics)
    are untouched.

    Args:
        m: module whose parameters (recursively, via ``m.parameters()``)
            should stop receiving gradients.
    """
    # nn.Module.requires_grad_ walks m.parameters() and flips each flag;
    # its return value (the module) is deliberately discarded.
    m.requires_grad_(False)
class GhostNetAABaseCOCOExp(DefaultBaseExp):
    """YOLOX experiment that uses a ``timm`` ``ghostnet_100`` as the backbone.

    The timm model is created with ``features_only=True`` and
    ``out_indices=[2, 3, 4]``, so it returns a list of three feature maps. A
    small wrapper keys them as ``"dark3"``, ``"dark4"`` and ``"dark5"``, the
    names passed to :class:`YOLOPAFPN` as ``in_features``. Any other timm
    architecture supporting ``features_only`` can be swapped in by changing
    the model name / ``out_indices``.

    Attributes set on top of the base ``Exp``:
        enable_mixup: overrides the base-class flag of the same name (set to
            False here; presumably disables MixUp augmentation).
        freeze_backbone: if True, ``freeze_model`` freezes the timm backbone.
        freeze_fpn: if True, ``freeze_model`` also freezes the PAFPN neck
            layers; requires ``freeze_backbone`` to be True as well.
    """

    def __init__(self):
        super().__init__()
        self.enable_mixup = False
        self.freeze_backbone = False
        self.freeze_fpn = False

    def get_model(self):
        """Build (on first call) and return the YOLOX model.

        On every call, YOLOX's BatchNorm settings and head-bias initialisation
        are applied, then ``load_pretrained_model_`` and ``freeze_model`` run.

        Returns:
            The ``YOLOX`` model stored on ``self.model``.
        """
        # Single source of truth for the stage names: the wrapper emits these
        # keys and YOLOPAFPN is told to read exactly these keys.
        feature_names = ("dark3", "dark4", "dark5")

        class TIMMWrap(nn.Module):
            """Expose a timm ``features_only`` model as a name -> tensor dict."""

            def __init__(self, bbone: nn.Module):
                super().__init__()
                # Attribute name kept as ``m`` so state_dict keys stay the same.
                self.m = bbone
                self.feature_names = feature_names

            def forward(self, x):
                # timm returns feature maps in out_indices order; key them by stage.
                return {k: v for k, v in zip(self.feature_names, self.m(x))}

        def init_yolo(model: nn.Module) -> None:
            # Same BatchNorm eps/momentum as YOLOX's reference setup. One pass
            # over model.modules() already reaches every submodule, so there is
            # no need to combine it with Module.apply (that was quadratic).
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            # Build the backbone once and read its channel counts from it,
            # rather than instantiating a throwaway pretrained copy just for
            # feature_info.
            bbone = timm.create_model(
                "ghostnet_100",
                features_only=True,
                out_indices=[2, 3, 4],
                pretrained=True,
            )
            in_channels = bbone.feature_info.channels()
            # NOTE(review): width=1 presumably keeps YOLOPAFPN/YOLOXHead from
            # rescaling the real timm channel counts -- confirm before changing.
            fpn = YOLOPAFPN(
                depth=1,
                width=1,
                in_features=feature_names,
                in_channels=in_channels,
                depthwise=True,
            )
            # Use the timm model as the neck's backbone.
            fpn.backbone = TIMMWrap(bbone)
            head = YOLOXHead(
                self.num_classes,
                width=1,
                in_channels=in_channels,
                depthwise=True,
            )
            self.model = YOLOX(fpn, head)

        init_yolo(self.model)
        self.model.head.initialize_biases(1e-2)
        self.load_pretrained_model_()
        self.freeze_model()
        return self.model

    def freeze_model(self):
        """Freeze backbone and/or FPN parameters according to the flags.

        Freezing uses ``freeze_layer`` (defined at module level in this file),
        which only sets ``requires_grad=False`` on parameters; buffers such as
        BatchNorm running statistics are not affected.

        ``self.model.backbone`` is accessed as the PAFPN neck (it exposes the
        FPN submodules listed below), and its ``.backbone`` is the timm wrapper
        assigned in ``get_model``.

        Raises:
            ValueError: if ``freeze_fpn`` is set without ``freeze_backbone``.
        """
        # Validate up front with an explicit raise: an ``assert`` would vanish
        # under ``python -O`` and silently allow an unsupported configuration.
        if self.freeze_fpn and not self.freeze_backbone:
            raise ValueError("must freeze backbone if freezing FPN")

        if self.freeze_backbone:
            freeze_layer(self.model.backbone.backbone)
            logger.info("Froze backbone")

        if self.freeze_fpn:
            fpn_modules = (
                "upsample",
                "lateral_conv0",
                "C3_p4",
                "reduce_conv1",
                "C3_p3",
                "bu_conv2",
                "C3_n3",
                "bu_conv1",
                "C3_n4",
            )
            for fpn_module in fpn_modules:
                freeze_layer(getattr(self.model.backbone, fpn_module))
            logger.info("Froze FPN")

    def load_pretrained_model_(self):
        """Hook for loading pretrained detector weights; a no-op here.

        Called by ``get_model`` before ``freeze_model``. Subclasses can
        override it to populate ``self.model``.
        """
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment