Residual-block
# import required libraries
import torch
from torch import nn
import torch.nn.functional as F
# Basic residual block of ResNet.
# Generic in the sense that it can also downsample features when a
# `downsample` callable and a stride > 1 are supplied.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=(1, 1), downsample=None):
        """
        A basic residual block of ResNet.

        Parameters
        ----------
        in_channels: int
            Number of channels in the input.
        out_channels: int
            Number of channels in the output.
        stride: tuple of int
            Strides of the two convolutional layers.
        downsample: callable, optional
            Applied to the input before it is added back as the residual,
            so that its shape matches the block's output.
        """
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride[0],
            padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride[1],
            padding=1, bias=False
        )
        # Each convolution gets its own BatchNorm so the two layers do not
        # share affine parameters and running statistics.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
    def forward(self, x):
        residual = x
        # Apply the downsample function so the residual's shape matches
        # the block's output before the addition.
        if self.downsample is not None:
            residual = self.downsample(residual)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Note that the residual is added before the final activation.
        out = out + residual
        out = F.relu(out)
        return out
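
# A minimal usage sketch of the block above. The strided 1x1 convolution
# used as the `downsample` callable is an assumption on my part (it is a
# common choice in ResNets for matching shapes), not something fixed by
# the block itself.
if __name__ == "__main__":
    x = torch.randn(1, 64, 56, 56)

    # Identity case: input and output shapes already match,
    # so no downsample callable is needed.
    block = ResidualBlock(64, 64)
    print(block(x).shape)  # torch.Size([1, 64, 56, 56])

    # Downsampling case: halve the spatial size and double the channels.
    # The residual branch is projected to the same shape with a strided
    # 1x1 convolution followed by batch norm.
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(128),
    )
    block_down = ResidualBlock(64, 128, stride=(2, 1), downsample=downsample)
    print(block_down(x).shape)  # torch.Size([1, 128, 28, 28])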