Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/58fd829bb48969f31b7694272daa89a0 to your computer and use it in GitHub Desktop.
Save SannaPersson/58fd829bb48969f31b7694272daa89a0 to your computer and use it in GitHub Desktop.
class ScalePrediction(nn.Module):
    """Prediction head for a single feature-map scale.

    Applies two ``CNNBlock`` layers that map an ``in_channels`` feature map to
    ``anchors_per_scale * (num_classes + 5)`` channels. ``forward`` then splits
    those channels into one prediction vector per anchor per spatial cell.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Number of object classes predicted per anchor.
        anchors_per_scale: Number of anchors predicted at each spatial cell of
            this scale. The head's output channel count is derived from it, so
            any positive value is supported (not only 3).
    """

    def __init__(self, in_channels, num_classes, anchors_per_scale):
        super().__init__()
        # Channels per anchor. NOTE(review): the "+ 5" is presumably
        # objectness + 4 box coordinates (YOLO convention) -- confirm against
        # the loss / box-decoding code.
        values_per_anchor = num_classes + 5
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            # Output channels MUST equal anchors_per_scale * values_per_anchor
            # so the reshape in ``forward`` is valid. This was previously
            # hard-coded to 3 anchors, which broke any other anchors_per_scale.
            # NOTE(review): bn_act=False presumably disables batch-norm and the
            # activation for this raw prediction layer -- confirm in CNNBlock.
            CNNBlock(
                2 * in_channels,
                values_per_anchor * anchors_per_scale,
                bn_act=False,
                kernel_size=1,
            ),
        )
        self.num_classes = num_classes
        self.anchors_per_scale = anchors_per_scale

    def forward(self, x):
        """Return per-anchor predictions for every spatial cell of ``x``.

        Args:
            x: Input feature map, indexed below as ``(batch, channels, H, W)``.
                NOTE(review): assumes NCHW layout -- confirm with the caller.

        Returns:
            Tensor of shape ``(batch, anchors_per_scale, H, W, num_classes + 5)``.
        """
        batch_size = x.shape[0]
        height, width = x.shape[2], x.shape[3]
        # The reshape uses the INPUT spatial size, so it assumes ``self.pred``
        # preserves H and W (3x3 conv with padding=1, then 1x1 conv).
        # NOTE(review): relies on CNNBlock defaulting to stride 1 -- confirm.
        out = self.pred(x).reshape(
            batch_size, self.anchors_per_scale, self.num_classes + 5, height, width
        )
        # Move the per-anchor prediction vector to the last axis:
        # (N, A, C+5, H, W) -> (N, A, H, W, C+5)
        return out.permute(0, 1, 3, 4, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment