Created February 6, 2020 19:54
-
-
Save cshimmin/0e5ea2d574720b96497c1c93ff9c09a0 to your computer and use it in GitHub Desktop.
Pseudo-Keras code for a masked MSE loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pseudo-code: build an encoder, a two-headed decoder, and chain them into an
# autoencoder graph. `...` / `# ...` mark parts left to the reader.
encoder_input = layers.Input(input_shape)
# ... define encoder sub-network
encoder = keras.models.Model(encoder_input, ...)

decoder_input = layers.Input(latent_shape)
# ... define decoder sub-network
x = next_to_last_decoder_layer
# Two output heads on the same features: cell values (relu, non-negative) and
# per-cell mask values (sigmoid, in [0, 1]) that the mask loss trains.
# NOTE(review): layers.Input expects a shape tuple while Dense expects an
# integer unit count; presumably input_shape is a flat size here — confirm.
decoder_output_cells = layers.Dense(input_shape, activation='relu')(x)
decoder_output_masks = layers.Dense(input_shape, activation='sigmoid')(x)
decoder = keras.models.Model(decoder_input, [decoder_output_cells, decoder_output_masks])

# Full autoencoder: input -> encoder -> decoder -> (cells, masks).
ae_input = layers.Input(input_shape)
ae_output_cells, ae_output_masks = decoder(encoder(ae_input))
# define loss in terms of the input tensor
def ae_loss(y1, y2):
    """Masked reconstruction loss for the two-headed autoencoder (pseudo-code).

    The loss is built from tensors captured by closure (``ae_input``,
    ``ae_output_cells``, ``ae_output_masks``); the ``y1``/``y2`` arguments
    that Keras passes in are not used.

    Two terms are combined:
      * a binary cross-entropy teaching the mask head which cells are active
        (input value above ``THRESHOLD``), summed over the last axis, and
      * a squared error between reconstructed and input cell values, summed
        over the active cells only and divided by the number of active cells.

    Args:
        y1: Unused.
        y2: Unused.

    Returns:
        Scalar tensor: the mean of ``masked_mse_loss + alpha * mask_loss``.
    """
    # Use crossentropy to learn binary masks derived from some energy threshold;
    # this could be a constant (e.g. 0) or computed as some fraction of the
    # total energy per-event.
    mask_true = K.cast(ae_input > THRESHOLD, 'float32')
    # The mask head's output is ae_output_masks (was the undefined
    # `sparse_mask_output`). Argument order is (target, output).
    mask_loss = K.sum(K.binary_crossentropy(mask_true, ae_output_masks), axis=-1)
    n_active_cells = K.sum(mask_true, axis=-1)  # or sum along whatever relevant axes
    # Reconstruction error is predicted cells vs. the input (was the undefined
    # `output_cells - output_mask`). Multiplying by mask_true zeroes inactive
    # cells; EPSILON avoids division by zero when no cell is active.
    masked_mse_loss = K.sum(
        K.square(mask_true * (ae_output_cells - ae_input)), axis=-1
    ) / (n_active_cells + EPSILON)
    # return a weighted combination of loss terms (alpha is a hyperparameter t.b.d.)
    return K.mean(masked_mse_loss + alpha * mask_loss)
# Wrap the autoencoder graph as a model whose outputs are both decoder heads.
# NOTE(review): with several outputs, Keras applies a single `loss` callable to
# each output separately, and fit(..., y=None) may be rejected when a compiled
# loss expects targets. Because ae_loss reads its tensors by closure,
# Model.add_loss is likely the better fit — confirm against the Keras version.
autoencoder = keras.models.Model(ae_input, [ae_output_cells, ae_output_masks])
autoencoder.compile(loss=ae_loss, optimizer=...)
# train AE (add epochs, batch_size, etc. as needed; a bare positional `...`
# after keyword arguments is a SyntaxError, so it is left out here)
autoencoder.fit(x=x_train, y=None)

# to generate, sample both the cell and masks values and use a threshold on the mask:
pred_cells, pred_masks = decoder.predict(np.random.normal(...))
# Keep each predicted cell value only where its mask probability exceeds 0.5.
samples = (pred_masks > 0.5) * pred_cells
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.