-
-
Save SannaPersson/5fc1043764ef6e1d6a99de7f068e48df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class YOLOv3(nn.Module):
    """YOLOv3-style network assembled from the module-level ``config`` list.

    ``config`` (defined elsewhere in this file) is read once, at construction
    time, by ``_create_conv_layers``. Each entry is interpreted as:

    * ``tuple`` ``(out_channels, kernel_size, stride)`` -> a ``CNNBlock``
      (padding 1 when ``kernel_size == 3``, otherwise 0);
    * ``list`` whose element ``[1]`` is a repeat count -> a ``ResidualBlock``;
    * ``"S"`` -> a prediction head: a ``ResidualBlock`` without residual,
      a 1x1 ``CNNBlock`` halving the channels, then a ``ScalePrediction``;
    * ``"U"`` -> ``nn.Upsample(scale_factor=2)``.

    Args:
        in_channels: Number of channels of the network input.
        num_classes: Number of classes passed to every ``ScalePrediction``.

    Raises:
        ValueError: If ``config`` contains an entry of any other kind.
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Run ``x`` through every layer and collect the prediction-head outputs.

        Outputs of ``ResidualBlock`` layers with ``num_repeats == 8`` are kept
        as route connections. Right after each ``nn.Upsample``, the most
        recently stored route is popped and concatenated onto ``x`` along
        dim 1 (presumably the channel axis of an NCHW tensor - TODO confirm).
        ``ScalePrediction`` layers read ``x`` but do not replace it, so the
        trunk continues from the tensor that fed the head.

        Args:
            x: Network input tensor.

        Returns:
            list: One entry per ``ScalePrediction`` layer, in layer order.

        Raises:
            IndexError: If an ``nn.Upsample`` is reached with no stored route.
        """
        outputs = []
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Consume the most recent route: last stored, first used.
                x = torch.cat([x, route_connections.pop()], dim=1)

        return outputs

    def _create_conv_layers(self):
        """Build the layer list from ``config``.

        The running channel count is tracked so that each block is created
        with the number of input channels produced by the block before it.

        Returns:
            nn.ModuleList: The layers, in ``config`` order.

        Raises:
            ValueError: If an entry of ``config`` is not a tuple, a list,
                ``"S"`` or ``"U"``.
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                # in_channels is left unchanged here - presumably
                # ResidualBlock preserves its channel count; TODO confirm.
                layers.append(ResidualBlock(in_channels, num_repeats=module[1]))

            elif isinstance(module, str) and module == "S":
                layers += [
                    ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                    CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                    ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                ]
                # The trunk continues from the 1x1 CNNBlock's output.
                in_channels = in_channels // 2

            elif isinstance(module, str) and module == "U":
                layers.append(nn.Upsample(scale_factor=2))
                # forward() concatenates a stored route onto the upsampled
                # tensor; the x3 assumes that route carries 2 * in_channels
                # channels - TODO confirm against config.
                in_channels = in_channels * 3

            else:
                # Fail loudly instead of silently dropping an unknown entry.
                raise ValueError(f"Unsupported YOLOv3 config entry: {module!r}")

        return layers
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment