Skip to content

Instantly share code, notes, and snippets.

@FrancescoSaverioZuppichini
Created September 23, 2018 16:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancescoSaverioZuppichini/0ba9c11ce6a44a183914386d9299e59c to your computer and use it in GitHub Desktop.
Save FrancescoSaverioZuppichini/0ba9c11ce6a44a183914386d9299e59c to your computer and use it in GitHub Desktop.
class MyEncoder(nn.Module):
    """Chain of ``conv_block`` layers applied one after another.

    Args:
        enc_sizes: channel counts. Each consecutive pair
            ``(enc_sizes[i], enc_sizes[i + 1])`` becomes the input/output
            channels of one ``conv_block(..., kernel_size=3, padding=1)``.
            A single-element list yields an empty ``nn.Sequential``.
    """

    def __init__(self, enc_sizes):
        super().__init__()
        blocks = []
        for c_in, c_out in zip(enc_sizes, enc_sizes[1:]):
            blocks.append(conv_block(c_in, c_out, kernel_size=3, padding=1))
        # NOTE(review): attribute name is misspelled ("blokcs"); renaming it would
        # change the module's state_dict keys, so it is left as-is.
        self.conv_blokcs = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every conv block in order and return the result."""
        out = self.conv_blokcs(x)
        return out
class MyDecoder(nn.Module):
    """Sequence of ``dec_block`` layers followed by a final ``nn.Linear`` layer.

    Args:
        dec_sizes: feature sizes. Each consecutive pair
            ``(dec_sizes[i], dec_sizes[i + 1])`` becomes the input/output size
            of one ``dec_block``; ``dec_sizes[-1]`` is the input size of the
            final linear layer.
        n_classes: output size of the final linear layer.
    """

    def __init__(self, dec_sizes, n_classes):
        super().__init__()
        self.dec_blocks = nn.Sequential(*[dec_block(in_f, out_f)
                                          for in_f, out_f in zip(dec_sizes, dec_sizes[1:])])
        self.last = nn.Linear(dec_sizes[-1], n_classes)

    def forward(self, x):
        """Apply the decoder blocks to ``x``, then the final linear layer.

        Returns:
            The output of ``self.last``, whose last dimension is ``n_classes``.
        """
        # The input must be passed to the Sequential, and the result must go
        # through ``self.last``; otherwise the linear head is never used.
        x = self.dec_blocks(x)
        return self.last(x)
class MyCNNClassifier(nn.Module):
    """Convolutional encoder, flatten, then fully connected decoder.

    Args:
        in_c: number of input channels; prepended to ``enc_sizes``.
        enc_sizes: output channel counts of the successive encoder blocks.
        dec_sizes: output sizes of the successive decoder blocks.
        n_classes: output size of the decoder's final linear layer.
        flat_features: number of features produced by flattening the encoder
            output; used as the input size of the first decoder layer.
            Defaults to ``32 * 28 * 28``, which presumably assumes the last
            encoder block has 32 channels and a 28x28 spatial output — set it
            explicitly for other configurations.
    """

    def __init__(self, in_c, enc_sizes, dec_sizes, n_classes, flat_features=32 * 28 * 28):
        super().__init__()
        self.enc_sizes = [in_c, *enc_sizes]
        self.dec_sizes = [flat_features, *dec_sizes]
        self.encoder = MyEncoder(self.enc_sizes)
        # Build the decoder from ``self.dec_sizes`` (not the raw ``dec_sizes``)
        # so its first layer accepts the flattened encoder output.
        self.decoder = MyDecoder(self.dec_sizes, n_classes)

    def forward(self, x):
        """Encode ``x``, flatten all dimensions after the first, then decode."""
        x = self.encoder(x)
        x = x.flatten(1)  # keep dim 0 (presumably the batch), merge the rest
        x = self.decoder(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment