@dhuynh95, last active September 23, 2019 06:09
Article 1 Snippet 2
import torch
import torch.nn as nn

# Dataset parameters. The helpers used below (get_data, create_encoder_denseblock,
# VAEBottleneck, create_decoder, VisionAELearner, VAEHook) come from the article's
# accompanying code.
train_size = 512
train_size = int(train_size * 1.25)  # scale up so that roughly 512 examples remain after the validation split
bs = 128
size = 28
data, valid_data = get_data(train_size, bs=bs, size=size)
# Architectural parameters of our model
conv = nn.Conv2d
act_fn = nn.ReLU
bn = nn.BatchNorm2d
rec_loss = "mse"
# Encoder architecture
enc_fn = create_encoder_denseblock
enc_args = {
    "n_dense": 3,
    "c_start": 4,
}
# Bottleneck architecture
bn_fn = VAEBottleneck
bn_args = {
    "nfs": [128, 14],
}
# Decoder architecture
dec_fn = create_decoder
dec_args = {
    "nfs": [14, 64, 32, 16, 8, 4, 2, 1],
    "ks": [3, 1, 3, 1, 3, 1],
    "size": 28,
}
# We create each part of the autoencoder
enc = enc_fn(**enc_args)
bn = bn_fn(**bn_args)  # the bottleneck module; note this reuses the `bn` name, shadowing the nn.BatchNorm2d reference above
dec = dec_fn(**dec_args)
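# For reference: a VAE bottleneck like the one above typically maps the encoder
# output to a mean and a log-variance and samples a latent code with the
# reparameterization trick. A minimal sketch with placeholder tensors (the names
# mu/logvar and the 14-dim latent size are assumptions, not VAEBottleneck's actual code):
mu, logvar = torch.randn(bs, 14), torch.randn(bs, 14)  # stand-ins for the bottleneck's outputs
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std  # latent code that would be fed to the decoder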
# We wrap the whole thing in a learner, and add a hook for the KL loss
learn = VisionAELearner(data, rec_loss, enc, bn, dec)
kl_hook = VAEHook(learn, beta=1)
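# For reference: the closed-form KL divergence between the approximate posterior
# N(mu, sigma^2) and the prior N(0, I) that a (beta-)VAE objective adds to the
# reconstruction loss. A sketch using the placeholder tensors above; the real
# computation here is handled by the VAEHook registered above.
beta = 1  # same weight as passed to VAEHook
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
kl_term = beta * kl.mean()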
# We register the decoder's sub-modules so the learner can plot reconstructions
dec_modules = list(learn.dec[1].children())
learn.set_dec_modules(dec_modules)
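Assuming VisionAELearner follows fastai's standard Learner interface (an assumption; the class is defined in the article's accompanying code), training the VAE would then look something like:

learn.fit_one_cycle(10, max_lr=1e-3)  # hypothetical call: 10 epochs with a one-cycle schedule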