Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matthewroche/a0abcb49c6c7ad123ad14e2dfa687d99 to your computer and use it in GitHub Desktop.
Save matthewroche/a0abcb49c6c7ad123ad14e2dfa687d99 to your computer and use it in GitHub Desktop.
Pooling Linear Classifier including Softmax
# Create custom classifier
class PoolingLinearClassifierSoftmax(nn.Module):
    """Concat-pooling classifier head whose final output is passed through softmax.

    The last encoder output is summarised by concatenating three tensors:
    its last slice along axis 0 (the final time step, if axis 0 is the
    sequence axis), its max-pool over axis 0, and its mean-pool over axis 0.
    The result is fed through a stack of ``LinearBlock``s, with ReLU applied
    between them.

    Args:
        layers: Layer widths. ``layers[0]`` must equal 3 * the hidden size of
            the encoder output, because three hidden-sized tensors are
            concatenated. ``layers[-1]`` is the number of outputs. At least
            two entries are required; otherwise no block is built and
            ``forward`` fails.
        drops: Per-block values; ``drops[i]`` is passed as the third argument
            to the i-th ``LinearBlock``. This presumably follows fastai's
            ``LinearBlock(n_in, n_out, dropout)`` signature — TODO confirm
            against the ``LinearBlock`` in scope.

    NOTE: ``forward`` returns probabilities, not logits. A loss that applies
    its own softmax (e.g. ``nn.CrossEntropyLoss``) would apply it twice.
    """

    def __init__(self, layers, drops):
        super().__init__()
        self.layers = nn.ModuleList([
            LinearBlock(layers[i], layers[i + 1], drops[i])
            for i in range(len(layers) - 1)
        ])

    def pool(self, x, bs, is_max):
        """Max- or mean-pool ``x`` over its first axis; return shape (bs, -1).

        ``x`` is assumed to be (seq_len, batch, hidden). It is permuted to
        (batch, hidden, seq_len) because the 1-D adaptive pools reduce the
        last axis, and an output size of 1 collapses that axis entirely.
        """
        f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
        # reshape rather than view, so a non-contiguous pool result cannot raise.
        return f(x.permute(1, 2, 0), (1,)).reshape(bs, -1)

    def forward(self, input):
        """Classify a batch from encoder outputs.

        Args:
            input: ``(raw_outputs, outputs)``. Only ``outputs[-1]`` is read.
                It must be 3-D and is assumed to be (seq_len, batch, hidden).

        Returns:
            A tuple ``(probs, raw_outputs, outputs)``:
            - ``probs`` is the softmax over the last axis of the final
              block's output, with shape (batch, layers[-1]).
            - ``raw_outputs`` and ``outputs`` are the inputs, passed through
              unchanged.
        """
        raw_outputs, outputs = input
        output = outputs[-1]
        bs = output.size(1)
        avgpool = self.pool(output, bs, False)
        mxpool = self.pool(output, bs, True)
        x = torch.cat([output[-1], mxpool, avgpool], 1)
        for l in self.layers:
            l_x = l(x)
            x = F.relu(l_x)
        # Pass dim explicitly: the implicit-dim form is deprecated and warns.
        # For this 2-D (batch, classes) tensor it resolved to dim=1 anyway.
        return F.softmax(l_x, dim=-1), raw_outputs, outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment