-
-
Save Sivaram46/d14b9054e06396d0ee26891802d9f4fb to your computer and use it in GitHub Desktop.
Residual-block
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import required libraries | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchvision | |
# Basic residual block of ResNet.
# It is generic: the same block can also be used to downsample features.
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block of ResNet.

    The block computes ``relu(bn(conv2(relu(bn(conv1(x))))) + shortcut(x))``
    where ``shortcut`` is the identity unless a ``downsample`` callable is
    supplied.
    """

    def __init__(self, in_channels: int, out_channels: int, stride=(1, 1),
                 downsample=None):
        """
        A basic residual block of ResNet

        Parameters
        ----------
        in_channels: Number of channels that the input has
        out_channels: Number of channels that the output has
        stride: Two-element sequence; ``stride[0]`` is the stride of the
            first 3x3 convolution and ``stride[1]`` of the second
        downsample: Optional callable applied to the input before it is
            added to the convolutional branch. It must return a tensor whose
            shape matches the branch output (needed whenever the strides or
            the channel counts change the shape).
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride[0],
            padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride[1],
            padding=1, bias=False
        )
        # NOTE(review): one BatchNorm2d instance is applied after both
        # convolutions, so both positions share affine parameters and running
        # statistics. The reference ResNet block uses one BN per convolution;
        # the shared layer is kept here so existing state_dict keys stay valid.
        self.bn = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x`` and return the activated sum."""
        residual = x

        # Bring the shortcut to the shape of the convolutional branch before
        # the addition. The callable lives on the instance, so it has to be
        # reached through ``self``.
        if self.downsample is not None:
            residual = self.downsample(residual)

        out = F.relu(self.bn(self.conv1(x)))
        out = self.bn(self.conv2(out))

        # The shortcut is added before the final activation.
        out = out + residual
        out = F.relu(out)

        return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment