Skip to content

Instantly share code, notes, and snippets.

@miki998
Created May 2, 2020 14:50
Show Gist options
  • Save miki998/2488d218dee278edb17be035004ba06a to your computer and use it in GitHub Desktop.
Save miki998/2488d218dee278edb17be035004ba06a to your computer and use it in GitHub Desktop.
class Yolov4Head(nn.Module):
    """Detection head that turns three feature maps into three YOLO outputs.

    The head has three branches, each ending in a ``YoloLayer`` configured
    with stride 8, 16 and 32 respectively. The second and third branches
    start from a stride-2 convolution of an earlier feature map that is then
    merged with the matching input (``input2`` / ``input3``).

    NOTE(review): the arguments to ``Conv_Bn_Activation`` are presumably
    ``(in_channels, out_channels, kernel_size, stride, activation)``; that
    helper is defined outside this block -- confirm against its definition.
    """

    # Flattened anchor values handed to every YoloLayer; each layer picks its
    # own subset through ``anchor_mask``. Presumably nine (width, height)
    # pairs -- confirm against YoloLayer.
    _ANCHORS = (12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110,
                192, 243, 459, 401)
    _NUM_ANCHORS = 9
    _NUM_CLASSES = 80

    def __init__(self):
        """Build the convolution stacks and the three YOLO layers."""
        super().__init__()

        # Branch 1 (stride 8).
        self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')
        self.conv2 = Conv_Bn_Activation(256, 255, 1, 1, 'linear', bn=False)
        self.yolo1 = self._make_yolo([0, 1, 2], 8)

        # R -4
        self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, 'leaky')

        # R -1 -16
        # Branch 2 (stride 16).
        self.conv4 = Conv_Bn_Activation(256, 256, 1, 1, 'leaky')
        self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
        self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
        self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
        self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
        self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
        # Activation name was misspelled 'liner'; use 'linear' like conv2.
        self.conv10 = Conv_Bn_Activation(512, 255, 1, 1, 'linear', bn=False)
        self.yolo2 = self._make_yolo([3, 4, 5], 16)

        # R -4
        self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, 'leaky')

        # R -1 -37
        # Branch 3 (stride 32).
        self.conv12 = Conv_Bn_Activation(512, 512, 1, 1, 'leaky')
        self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
        self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
        self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
        self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
        self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
        # Activation name was misspelled 'liner'; use 'linear' like conv2.
        self.conv18 = Conv_Bn_Activation(1024, 255, 1, 1, 'linear', bn=False)
        self.yolo3 = self._make_yolo([6, 7, 8], 32)

    def _make_yolo(self, anchor_mask, stride):
        """Create a YoloLayer sharing this head's anchors and class count."""
        return YoloLayer(
            anchor_mask=anchor_mask,
            num_classes=self._NUM_CLASSES,
            # Fresh list per layer, as in the original, so no layer can
            # affect another by mutating its anchors.
            anchors=list(self._ANCHORS),
            num_anchors=self._NUM_ANCHORS,
            stride=stride,
        )

    def forward(self, input1, input2, input3):
        """Run the three detection branches.

        Args:
            input1: Feature map fed to ``conv1`` and ``conv3``.
            input2: Feature map added to the output of ``conv3``.
            input3: Feature map added to the output of ``conv11``.

        Returns:
            A list ``[y1, y2, y3]`` with the outputs of ``yolo1``, ``yolo2``
            and ``yolo3`` in that order.
        """
        # Branch 1.
        x1 = self.conv1(input1)
        x2 = self.conv2(x1)
        y1 = self.yolo1(x2)

        # Branch 2: downsample input1 and merge with input2.
        x3 = self.conv3(input1)
        # R -1 -16
        # NOTE(review): the merge is an element-wise add, so both tensors
        # must have matching shapes. If "R" denotes a Darknet route layer,
        # confirm that addition (not channel concatenation) is intended.
        x3 = x3 + input2
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        x8 = self.conv8(x7)
        x9 = self.conv9(x8)
        x10 = self.conv10(x9)
        y2 = self.yolo2(x10)

        # Branch 3: starts from x8 (the conv8 output), not from x10.
        # R -4
        x11 = self.conv11(x8)
        # R -1 -37
        x11 = x11 + input3
        x12 = self.conv12(x11)
        x13 = self.conv13(x12)
        x14 = self.conv14(x13)
        x15 = self.conv15(x14)
        x16 = self.conv16(x15)
        x17 = self.conv17(x16)
        x18 = self.conv18(x17)
        y3 = self.yolo3(x18)

        return [y1, y2, y3]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment