Skip to content

Instantly share code, notes, and snippets.

@insujeon
Last active August 11, 2017 07:53
Show Gist options
  • Save insujeon/0d30458845ce7000f44a3c9020361278 to your computer and use it in GitHub Desktop.
Save insujeon/0d30458845ce7000f44a3c9020361278 to your computer and use it in GitHub Desktop.
class Maxout(nn.Module):
    """Maxout layer: an affine map followed by a max over groups of pieces.

    The input is projected to ``d_out * pool_size`` features by a single
    ``nn.Linear``. Those features are viewed as ``(d_out, pool_size)``, and
    the maximum over the ``pool_size`` pieces is kept for each of the
    ``d_out`` outputs.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Number of output features.
        pool_size: Number of linear pieces maxed over per output feature.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """

    def __init__(self, d_in, d_out, pool_size):
        super().__init__()
        if pool_size < 1:
            # A max over zero pieces is undefined; fail early rather than
            # deep inside forward().
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.d_in, self.d_out, self.pool_size = d_in, d_out, pool_size
        self.lin = nn.Linear(d_in, d_out * pool_size)

    def forward(self, inputs):
        """Apply the maxout transform.

        Args:
            inputs: Tensor whose last dimension has size ``d_in``. Any number
                of leading dimensions is allowed, as with ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``d_out``.
        """
        out = self.lin(inputs)
        # Split the last dimension into (d_out, pool_size) so that each output
        # feature owns a contiguous group of pool_size pieces.
        pieces = out.view(*out.shape[:-1], self.d_out, self.pool_size)
        # max(dim) returns (values, indices); only the values are needed.
        return pieces.max(dim=-1)[0]

    def extra_repr(self):
        return f"d_in={self.d_in}, d_out={self.d_out}, pool_size={self.pool_size}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment