Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Created November 9, 2020 12:23
Show Gist options
  • Save seanbenhur/379db0160c201079bf60377cdab0ba94 to your computer and use it in GitHub Desktop.
Save seanbenhur/379db0160c201079bf60377cdab0ba94 to your computer and use it in GitHub Desktop.
# The bottleneck residual block: the basic building unit that is stacked to form a ResNet.
class ResBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    The main path reduces ``in_channels`` to ``int_channels`` with a 1x1
    convolution, applies a 3x3 convolution (which carries ``stride``), then
    expands to ``int_channels * expansion`` channels with a final 1x1
    convolution. Every convolution is followed by batch norm; ReLU follows the
    first two, and is applied once more after the skip connection is added.

    Args:
        in_channels: Number of channels of the input tensor.
        int_channels: Number of intermediate ("bottleneck") channels.
        identity_downsample: Optional module applied to the input on the skip
            path. It must map the input to the main path's output shape
            ``(N, int_channels * expansion, H_out, W_out)``, because the two
            are added element-wise. Leave it ``None`` only when the input
            already has that shape (typically ``stride == 1`` and
            ``in_channels == int_channels * expansion``).
        stride: Stride of the 3x3 convolution.
    """

    # Channel multiplier of the final 1x1 convolution. Kept as a class
    # attribute so callers can read ``ResBlock.expansion`` before building a
    # block (e.g. to size ``identity_downsample``); ``self.expansion`` still
    # works on instances.
    expansion = 4

    def __init__(self, in_channels, int_channels, identity_downsample=None, stride=1):
        super().__init__()
        out_channels = int_channels * self.expansion

        # 1x1: in_channels -> int_channels
        self.conv1 = nn.Conv2d(in_channels, int_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(int_channels)
        # 3x3 with padding=1; the block's stride is applied here.
        self.conv2 = nn.Conv2d(int_channels, int_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(int_channels)
        # 1x1: int_channels -> int_channels * expansion
        self.conv3 = nn.Conv2d(int_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return the ReLU of the residual sum.

        Args:
            x: Input batch of shape ``(N, in_channels, H, W)`` (``BatchNorm2d``
                requires 4-D input).

        Returns:
            Tensor with ``int_channels * expansion`` channels; the spatial size
            is reduced by the 3x3 convolution's ``stride``.
        """
        # A plain reference suffices for the skip path: every layer below
        # returns a new tensor (ReLU is not in-place), so ``x`` itself is never
        # modified. Cloning it would only cost an extra full-tensor copy.
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The skip connection: project the input when the main path changes
        # its channel count or spatial size.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # ``out`` is a fresh tensor created in this call, so adding in place
        # does not touch the caller's input.
        out += identity
        return self.relu(out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment