Skip to content

Instantly share code, notes, and snippets.

@FrancescoSaverioZuppichini
Created September 23, 2018 16:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancescoSaverioZuppichini/45d449493b04dffae64f334965824d94 to your computer and use it in GitHub Desktop.
Save FrancescoSaverioZuppichini/45d449493b04dffae64f334965824d94 to your computer and use it in GitHub Desktop.
class MyCNNClassifier(nn.Module):
    """Convolutional encoder followed by a fully-connected classification head.

    The encoder chains one ``conv_block`` (kernel_size=3, padding=1) per
    consecutive pair in ``[in_c, *enc_sizes]``. Its output is flattened per
    sample and fed to ``Linear -> Sigmoid -> Linear``, which produces
    ``n_classes`` scores per sample.

    Args:
        in_c: Number of channels of the input.
        enc_sizes: Output channel counts of the successive conv blocks. If it
            is empty, the encoder is an empty ``nn.Sequential`` and passes
            the input through unchanged.
        n_classes: Number of output scores per sample.
        img_size: Spatial size of the encoder output, either an int (square)
            or a ``(height, width)`` pair. Defaults to 28, the value that was
            previously hard-coded. NOTE(review): this assumes ``conv_block``
            preserves spatial size (3x3 kernel with padding 1 and no pooling),
            so that this equals the input image size. Confirm against the
            ``conv_block`` definition.
    """

    def __init__(self, in_c, enc_sizes, n_classes, img_size=28):
        super().__init__()
        self.enc_sizes = [in_c, *enc_sizes]
        # One conv block per (in_channels, out_channels) pair of neighbours.
        conv_blocks = [
            conv_block(in_f, out_f, kernel_size=3, padding=1)
            for in_f, out_f in zip(self.enc_sizes, self.enc_sizes[1:])
        ]
        self.encoder = nn.Sequential(*conv_blocks)

        if isinstance(img_size, int):
            height, width = img_size, img_size
        else:
            height, width = img_size
        # Derive the flattened feature count from the last encoder width
        # instead of hard-coding 32, so the head matches any ``enc_sizes``.
        flat_features = self.enc_sizes[-1] * height * width
        self.decoder = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, n_classes),
        )

    def forward(self, x):
        """Encode ``x``, flatten every dim after the first, and classify.

        Returns the decoder output, with one row of ``n_classes`` scores per
        entry along dim 0 of ``x``.
        """
        x = self.encoder(x)
        # flatten(1) keeps dim 0 and merges the rest. Unlike
        # view(size(0), -1), it also works on non-contiguous tensors and on
        # an empty batch, where -1 would be ambiguous.
        x = x.flatten(1)
        return self.decoder(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment