Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Created November 9, 2020 11:29
Show Gist options
  • Save seanbenhur/51c8c4665a2efeacd765a43b77742cb1 to your computer and use it in GitHub Desktop.
Save seanbenhur/51c8c4665a2efeacd765a43b77742cb1 to your computer and use it in GitHub Desktop.
class InceptionAux(nn.Module):
    """Auxiliary classifier head applied to an intermediate feature map.

    Pipeline::

        AvgPool2d(5, stride=3) -> conv_block 1x1 to 128 channels -> flatten
        -> Linear(2048, 1024) -> ReLU -> Dropout(p=0.7) -> Linear(1024, num_classes)

    No softmax is applied; ``forward`` returns raw scores.

    Args:
        in_channels: Channel count of the incoming feature map, passed to
            ``conv_block``.
        num_classes: Number of output scores produced by ``fc2``.

    Note:
        ``fc1`` expects exactly 2048 flattened features, i.e. 128 channels times
        16 pooled positions. ``AvgPool2d(5, stride=3)`` has no padding, so a
        14x14 map pools to 4x4 = 16 positions. This assumes ``conv_block``
        keeps the spatial size for a 1x1 kernel (TODO confirm against its
        definition). Inputs whose pooled map does not have 16 positions will
        fail with a shape mismatch in ``fc1``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        reduced_channels = 128
        hidden_units = 1024
        # Keep this construction order: it fixes submodule registration order
        # and the order in which default-initialised layers draw random weights.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.7)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = conv_block(in_channels, reduced_channels, kernel_size=1)
        # reduced_channels * 4 * 4 == 2048: 128 channels on a 4x4 pooled grid.
        self.fc1 = nn.Linear(reduced_channels * 4 * 4, hidden_units)
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Map an intermediate feature map to unnormalized class scores.

        Args:
            x: Feature map, presumably ``(batch, in_channels, H, W)`` (TODO
                confirm with the caller). The leading dimension is kept as the
                row count.

        Returns:
            Tensor of shape ``(x.shape[0], num_classes)`` with pre-softmax scores.
        """
        pooled = self.pool(x)
        reduced = self.conv(pooled)
        # Collapse every dimension except the leading (batch) one for the dense layers.
        flat = reduced.reshape(reduced.shape[0], -1)
        hidden = self.relu(self.fc1(flat))
        # Dropout only zeroes activations while the module is in training mode.
        hidden = self.dropout(hidden)
        logits = self.fc2(hidden)
        return logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment