Skip to content

Instantly share code, notes, and snippets.

@rwightman
Last active January 27, 2023 06:09
Show Gist options
  • Save rwightman/fc82d8b7219a60c769b1e7b9bd9442d6 to your computer and use it in GitHub Desktop.
Save rwightman/fc82d8b7219a60c769b1e7b9bd9442d6 to your computer and use it in GitHub Desktop.
Use effdet BiFPN standalone
from dataclasses import dataclass, field
from typing import Callable, Optional, Union

import timm
import torch.nn as nn
from effdet.config import fpn_config
from effdet.efficientdet import BiFpn
from omegaconf import DictConfig
@dataclass
class StandaloneConfig:
    """Config object handed to effdet's ``BiFpn`` so it can be used without a full EfficientDet.

    Attribute names are chosen to match what ``BiFpn`` looks up on its config;
    exactly which ones a given effdet version reads is not visible from here.
    """
    min_level: int = 3
    max_level: int = 7
    # Always recomputed from min_level/max_level in __post_init__; a passed-in value is overwritten.
    num_levels: int = max_level - min_level + 1
    pad_type: str = ''  # use 'same' for TF style SAME padding
    act_type: str = 'silu'
    norm_layer: Optional[Callable] = None  # defaults to batch norm when None
    separable_conv: bool = True
    apply_resample_bn: bool = True
    conv_after_downsample: bool = False
    conv_bn_relu_pattern: bool = False
    use_native_resize_op: bool = False
    downsample_type: str = 'bilinear'
    upsample_type: str = 'bilinear'
    redundant_bias: bool = False
    fpn_cell_repeats: int = 3
    fpn_channels: int = 88
    fpn_name: str = 'bifpn_fa'
    fpn_config: Optional[DictConfig] = None  # determines FPN connectivity, if None, use default for type (name)
    # Keyword args for the norm layer. This is a real dataclass field, so it can be passed to the
    # constructor. default_factory gives each instance its own dict rather than one shared
    # class-level dict. It is declared last so the positional order of the other fields is unchanged.
    norm_kwargs: Optional[dict] = field(default_factory=lambda: dict(eps=.001, momentum=.01))

    def __post_init__(self):
        """Validate the level range and derive ``num_levels`` from it.

        Raises:
            ValueError: if ``max_level`` is smaller than ``min_level`` (this would give
                ``num_levels <= 0``).
        """
        if self.max_level < self.min_level:
            raise ValueError(
                f'max_level ({self.max_level}) must be >= min_level ({self.min_level})')
        self.num_levels = self.max_level - self.min_level + 1
class ExampleNet(nn.Module):
    """Example model: a timm feature-extraction backbone feeding a standalone effdet ``BiFpn``.

    Args:
        config: FPN config passed straight to ``BiFpn`` (e.g. a ``StandaloneConfig``).
        backbone: timm model name; created with ``features_only=True``.
        backbone_indices: backbone stages to return (timm ``out_indices``). Their
            ``feature_info`` dicts are what ``BiFpn`` is constructed from.
        pretrained: whether timm loads pretrained backbone weights. Defaults to True, the
            previously hard-coded behavior. Pass False for randomly initialized weights.
    """

    def __init__(self, config, backbone='resnet50', backbone_indices=(2, 3, 4), pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            backbone, features_only=True, out_indices=backbone_indices, pretrained=pretrained)
        # BiFpn sizes its input resampling from the backbone's per-stage feature metadata.
        self.bifpn = BiFpn(config, self.backbone.feature_info.get_dicts())

    def forward(self, x):
        """Run the backbone, then pass its feature maps through the BiFPN and return its output."""
        features = self.backbone(x)
        return self.bifpn(features)

Setup

# in torch env
pip install timm
pip install effdet

BiFPN, 5 levels, 3 repeats

import torch  # needed for torch.randn; `import torch.nn as nn` does not bind `torch`

sc = StandaloneConfig()  # creates a bifpn layout with fast attn
e = ExampleNet(sc)
o = e(torch.randn(2, 3, 512, 512))
for x in o:
    print(x.shape)

torch.Size([2, 88, 64, 64])
torch.Size([2, 88, 32, 32])
torch.Size([2, 88, 16, 16])
torch.Size([2, 88, 8, 8])
torch.Size([2, 88, 4, 4])

QuadFPN, 3 levels, 4 repeats

sc = StandaloneConfig(fpn_name='qufpn_fa', fpn_cell_repeats=4, fpn_channels=128, min_level=5, max_level=7)
e = ExampleNet(sc)
o = e(torch.randn(2, 3, 512, 512))
for x in o:
    print(x.shape)
    
torch.Size([2, 128, 16, 16])
torch.Size([2, 128, 8, 8])
torch.Size([2, 128, 4, 4])

PAN, 4 levels, 5 repeats

sc = StandaloneConfig(fpn_name='pan_fa', fpn_cell_repeats=5, fpn_channels=128, min_level=4, max_level=7)
e = ExampleNet(sc)
o = e(torch.randn(2, 3, 512, 512))
for x in o:
    print(x.shape)

torch.Size([2, 128, 32, 32])
torch.Size([2, 128, 16, 16])
torch.Size([2, 128, 8, 8])
torch.Size([2, 128, 4, 4])

Diagram

FPN vs PAN vs BiFPN from the EfficientDet paper: https://arxiv.org/abs/1911.09070

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment