Skip to content

Instantly share code, notes, and snippets.

@Sivaram46
Created June 12, 2021 17:35

Revisions

  1. Sivaram46 created this gist Jun 12, 2021.
    18 changes: 18 additions & 0 deletions dense_layer.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,18 @@
    class Dense_Layer(nn.Module):
    def __init__(self, in_channels, growthrate, bn_size):
    super(Dense_Layer, self).__init__()

    self.bn1 = nn.BatchNorm2d(in_channels)
    self.conv1 = nn.Conv2d(
    in_channels, bn_size * growthrate, kernel_size=1, bias=False
    )
    self.bn2 = nn.BatchNorm2d(bn_size * growthrate)
    self.conv2 = nn.Conv2d(
    bn_size * growthrate, growthrate, kernel_size=3, padding=1, bias=False
    )

    def forward(self, prev_features):
    out1 = torch.cat(prev_features, dim=1)
    out1 = self.conv1(F.relu(self.bn1(out1)))
    out2 = self.conv2(F.relu(self.bn2(out1)))
    return out2