Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/5fc1043764ef6e1d6a99de7f068e48df to your computer and use it in GitHub Desktop.
Save SannaPersson/5fc1043764ef6e1d6a99de7f068e48df to your computer and use it in GitHub Desktop.
class YOLOv3(nn.Module):
    """YOLOv3 detector assembled from a layer-specification list.

    The layers are built by ``_create_conv_layers`` from ``architecture``
    (defaulting to the module-level ``config``). Each entry is read as:

    * ``(out_channels, kernel_size, stride)`` -- a ``CNNBlock``;
    * ``[tag, num_repeats]`` -- a ``ResidualBlock`` with ``num_repeats``
      repeats (only ``entry[1]`` is read; the tag is not used);
    * ``"S"`` -- a prediction head: a non-residual ``ResidualBlock``, a 1x1
      ``CNNBlock`` halving the channels, then a ``ScalePrediction``;
    * ``"U"`` -- an ``nn.Upsample(scale_factor=2)`` whose output is
      concatenated (dim 1) with the most recently saved route feature map.

    Route feature maps are the outputs of ``ResidualBlock`` layers whose
    ``num_repeats`` equals ``_ROUTE_NUM_REPEATS``; ``"U"`` entries consume
    them last-in, first-out.

    Args:
        in_channels: Channel count of the input tensor.
        num_classes: Forwarded to every ``ScalePrediction`` head.
        architecture: Optional layer specification; ``None`` uses the
            module-level ``config``.
    """

    # ResidualBlocks with this many repeats produce the skip ("route")
    # feature maps. forward() and the builder must apply the same rule so
    # the channel bookkeeping matches what is concatenated at run time.
    _ROUTE_NUM_REPEATS = 8

    def __init__(self, in_channels=3, num_classes=80, architecture=None):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers(architecture)

    def forward(self, x):
        """Run the network and collect the output of every prediction head.

        Args:
            x: Input batch. NOTE(review): presumably (N, C, H, W) with
                C == in_channels -- TODO confirm; routes are concatenated
                along dim 1.

        Returns:
            A list with one entry per ``ScalePrediction`` layer, in the order
            those layers appear in ``self.layers``.
        """
        outputs = []
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off the trunk: record their output, but keep
                # feeding the head's *input* to the following layers.
                outputs.append(layer(x))
                continue

            x = layer(x)

            if (
                isinstance(layer, ResidualBlock)
                and layer.num_repeats == self._ROUTE_NUM_REPEATS
            ):
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Merge with the most recent route that has not been used yet.
                x = torch.cat([x, route_connections.pop()], dim=1)
        return outputs

    def _create_conv_layers(self, architecture=None):
        """Build the ``nn.ModuleList`` described by ``architecture``.

        Args:
            architecture: Layer specification (see the class docstring);
                ``None`` uses the module-level ``config``.

        Returns:
            An ``nn.ModuleList`` holding the layers in execution order.

        Raises:
            ValueError: If an entry is not a tuple, a list, ``"S"`` or ``"U"``,
                or if a ``"U"`` entry has no earlier route to concatenate with.
        """
        if architecture is None:
            architecture = config

        layers = nn.ModuleList()
        in_channels = self.in_channels
        # Channel widths of saved route maps, pushed/popped in the same LIFO
        # order forward() uses, so "U" knows the concatenated width exactly.
        route_channels = []

        for position, module in enumerate(architecture):
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        # 3x3 kernels are padded by 1, every other size by 0.
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(
                    ResidualBlock(
                        in_channels,
                        num_repeats=num_repeats,
                    )
                )
                # The builder assumes ResidualBlock preserves the channel
                # count, so in_channels is left unchanged here.
                if num_repeats == self._ROUTE_NUM_REPEATS:
                    route_channels.append(in_channels)

            elif isinstance(module, str) and module == "S":
                layers.extend(
                    [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                )
                in_channels = in_channels // 2

            elif isinstance(module, str) and module == "U":
                if not route_channels:
                    raise ValueError(
                        f"architecture entry {position} ('U') has no earlier "
                        f"route feature map to concatenate with"
                    )
                layers.append(nn.Upsample(scale_factor=2))
                # Output width = upsampled channels + the popped route's channels.
                in_channels += route_channels.pop()

            else:
                # Previously such entries were silently skipped, yielding a
                # malformed network that only failed later inside forward().
                raise ValueError(
                    f"unrecognised architecture entry {position}: {module!r}"
                )

        return layers
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment