# we use Keras to build the graph
import keras.backend as K
from keras.layers import Dense, Lambda

# build your encoder up to here. It can simply be a series of dense layers, a convolutional
# network, or even an LSTM encoder. Once built, flatten out the final layer of the encoder
# and call it hidden.

latent_size = 5

mean = Dense(latent_size)(hidden)

# we usually don't directly compute the stddev σ,
# but the log of the stddev instead, which is log(σ).
# the reasoning is similar to why we use softmax instead of directly outputting
# numbers in the fixed range [0, 1]: the network can output a wider range of numbers,
# which we can later compress down
log_stddev = Dense(latent_size)(hidden)

def sampler(args):
    # Lambda passes its inputs as a single list, so unpack them here
    mean, log_stddev = args
    # sample from the standard normal a matrix of batch_size * latent_size
    # (taking minibatches into account)
    std_norm = K.random_normal(shape=(K.shape(mean)[0], latent_size), mean=0, stddev=1)
    # sampling from Z ~ N(μ, σ²) is the same as sampling μ + σX, with X ~ N(0, 1)
    return mean + K.exp(log_stddev) * std_norm

latent_vector = Lambda(sampler)([mean, log_stddev])
# pass latent_vector as input to the decoder layers
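To see that the shift-and-scale trick really produces samples from N(μ, σ²), here is a minimal NumPy sketch of the same reparameterization (the function name `reparameterize` and the specific μ = 2.0, σ = 0.5 values are illustrative, not from the original gist):

```python
import numpy as np

def reparameterize(mean, log_stddev, rng):
    # sample X ~ N(0, 1), then shift and scale: z = mean + exp(log_stddev) * X
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(log_stddev) * eps

rng = np.random.default_rng(0)
n = 100_000
mean = np.full((n, 1), 2.0)
log_stddev = np.full((n, 1), np.log(0.5))  # i.e. stddev σ = 0.5

z = reparameterize(mean, log_stddev, rng)
# empirically, z.mean() should be close to 2.0 and z.std() close to 0.5
```

Because the randomness lives entirely in the fixed N(0, 1) draw, gradients can flow through `mean` and `log_stddev`, which is exactly why the Keras sampler above is written this way.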