Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Created November 9, 2020 11:35
Show Gist options
  • Save seanbenhur/9573e3c0da6e87b61d82ce167e1a7e6c to your computer and use it in GitHub Desktop.
Save seanbenhur/9573e3c0da6e87b61d82ce167e1a7e6c to your computer and use it in GitHub Desktop.
class Inception_block(nn.Module):
    """Four parallel branches whose outputs are concatenated along dim 1.

    Every branch receives the same input tensor:

    * ``branch1``: a single 1x1 ``conv_block``.
    * ``branch2``: a 1x1 ``conv_block`` followed by a 3x3 one (padding 1).
    * ``branch3``: a 1x1 ``conv_block`` followed by a 5x5 one (padding 2).
    * ``branch4``: a 3x3 max-pool (stride 1, padding 1) followed by a 1x1
      ``conv_block``.

    NOTE(review): ``conv_block`` is defined outside this block. The
    parameter descriptions below assume it is a stride-1 convolution wrapper
    called as ``conv_block(in_channels, out_channels, **conv_kwargs)`` --
    confirm against its definition.

    Args:
        in_channels: Channel count handed to the first layer of each branch.
        out_1x1: Second argument of the 1x1 ``conv_block`` in ``branch1``
            (presumably its output channels).
        red_3x3: Width of the 1x1 reduction that feeds the 3x3 ``conv_block``.
        out_3x3: Second argument of the 3x3 ``conv_block`` in ``branch2``.
        red_5x5: Width of the 1x1 reduction that feeds the 5x5 ``conv_block``.
        out_5x5: Second argument of the 5x5 ``conv_block`` in ``branch3``.
        out_1x1pool: Second argument of the 1x1 ``conv_block`` that follows
            the max-pool in ``branch4``.
    """

    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
        super().__init__()

        self.branch1 = conv_block(in_channels, out_1x1, kernel_size=(1, 1))

        self.branch2 = nn.Sequential(
            conv_block(in_channels, red_3x3, kernel_size=(1, 1)),
            conv_block(red_3x3, out_3x3, kernel_size=(3, 3), padding=(1, 1)),
        )

        self.branch3 = nn.Sequential(
            conv_block(in_channels, red_5x5, kernel_size=(1, 1)),
            conv_block(red_5x5, out_5x5, kernel_size=(5, 5), padding=(2, 2)),
        )

        # Kernel 3 / stride 1 / padding 1 keeps the pooled map the same
        # height and width as the input, so it can be concatenated later.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            conv_block(in_channels, out_1x1pool, kernel_size=(1, 1)),
        )

    def forward(self, x):
        """Run ``x`` through all four branches and concatenate along dim 1.

        Assumes ``x`` is batched as (N, C, H, W) so that dim 1 is the channel
        axis -- TODO confirm against callers. All branch outputs must agree
        on every dimension except dim 1, otherwise ``torch.cat`` raises.
        """
        branch_outputs = [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
        ]
        return torch.cat(branch_outputs, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment