Created
February 25, 2016 05:45
-
-
Save ericjang/b01328d4e67d3c283c12 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def binary_crossentropy(t,o):
    """Return the element-wise binary cross-entropy of `o` against `t`.

    `t` is the target and `o` the prediction. `eps` (defined elsewhere in
    this file) is added inside each logarithm so that neither log is
    evaluated at exactly zero. No reduction is applied here; the caller
    sums/averages the result.
    """
    # Contribution where the target is "on".
    on_term = t * tf.log(o + eps)
    # Contribution where the target is "off".
    off_term = (1.0 - t) * tf.log(1.0 - o + eps)
    return -(on_term + off_term)
# ---- Reconstruction loss (Lx) ----
# NOTE: the reconstruction term appears to have been collapsed down to a single
# scalar value (rather than one per item in the minibatch).
x_recons=tf.nn.sigmoid(cs[-1]) # sigmoid of the last entry of cs (presumably the final canvas -- confirm)

# After computing binary cross entropy, sum across features (axis 1), then
# take the mean of those sums across the minibatch.
Lx=tf.reduce_sum(binary_crossentropy(x,x_recons),1) # reconstruction term
Lx=tf.reduce_mean(Lx)

# ---- Latent loss (Lz) ----
# One KL term per time step t in 0..T-1, built from mus[t], sigmas[t] and
# logsigmas[t].
kl_terms=[0]*T
for t in range(T):
    mu2=tf.square(mus[t])
    sigma2=tf.square(sigmas[t])
    logsigma=logsigmas[t]
    # Subtract 0.5 per step so that the constant totals T/2 once the T terms
    # are added together below, as in the DRAW latent loss
    #   Lz = 0.5 * sum_t(mu_t^2 + sigma_t^2 - log sigma_t^2) - T/2.
    # Subtracting T*.5 here instead would accumulate to T*T/2 over the loop
    # and shift the reported loss by a constant (gradients are unaffected).
    kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-.5 # each kl term is (1xminibatch)
KL=tf.add_n(kl_terms) # this is 1xminibatch, corresponding to summing kl_terms from 1:T
Lz=tf.reduce_mean(KL) # average over minibatches

# Total training objective: reconstruction loss plus latent loss.
cost=Lx+Lz
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment