Skip to content

Instantly share code, notes, and snippets.

@Sivaram46
Created June 12, 2021 17:37
Show Gist options
  • Save Sivaram46/02c251774587ba0f3778a99ca0c7ac7b to your computer and use it in GitHub Desktop.
Save Sivaram46/02c251774587ba0f3778a99ca0c7ac7b to your computer and use it in GitHub Desktop.
class Dense_Block(nn.ModuleDict):
    """A dense block made of ``n_layers`` stacked ``Dense_Layer`` modules.

    Each layer is fed every feature map produced so far (the block input plus
    the outputs of all earlier layers). The block returns all of those maps
    concatenated along ``dim=1``.

    NOTE(review): the class derives from ``nn.ModuleDict`` yet stores its
    layers in a nested ``nn.ModuleDict`` attribute ``self.block``, so the
    outer dict only ever holds a single ``'block'`` entry. The base class and
    attribute name are kept unchanged so ``isinstance`` checks and checkpoint
    keys (``block.dense{i}.*``) remain compatible.
    """

    def __init__(self, n_layers, in_channels, growthrate, bn_size):
        """
        Parameters
        ----------
        n_layers: Number of dense layers to be stacked
        in_channels: Number of input channels for first layer in the block
        growthrate: Growth rate (k) as mentioned in DenseNet paper
        bn_size: Multiplicative factor for # of bottleneck layers
        """
        super().__init__()
        # Insert layers one at a time so their order never depends on how
        # nn.ModuleDict treats a plain dict (older PyTorch sorted its keys,
        # which would run 'dense10' before 'dense2').
        self.block = nn.ModuleDict()
        for i in range(n_layers):
            # Layer i sees the block input plus the outputs of i earlier
            # layers; assumes each Dense_Layer emits `growthrate` channels
            # (TODO confirm against Dense_Layer).
            self.block[f"dense{i}"] = Dense_Layer(
                in_channels + i * growthrate, growthrate, bn_size
            )

    def forward(self, features):
        """Run every dense layer and concatenate all feature maps.

        Parameters
        ----------
        features: A single tensor, or a sequence of tensors, forming the
            block input. A caller-supplied sequence is copied, never modified.

        Returns
        -------
        The input feature map(s) followed by each layer's output,
        concatenated along ``dim=1``.
        """
        # Always work on a private list: appending to the caller's list would
        # silently grow it by one tensor per layer.
        if isinstance(features, torch.Tensor):
            features = [features]
        else:
            features = list(features)
        for layer in self.block.values():
            # Each layer receives the full list of maps produced so far
            # (presumably Dense_Layer concatenates them itself; confirm).
            features.append(layer(features))
        return torch.cat(features, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment