Skip to content

Instantly share code, notes, and snippets.

@cshimmin
Created February 6, 2020 19:54
Show Gist options
  • Save cshimmin/0e5ea2d574720b96497c1c93ff9c09a0 to your computer and use it in GitHub Desktop.
Save cshimmin/0e5ea2d574720b96497c1c93ff9c09a0 to your computer and use it in GitHub Desktop.
pseudo-keras code for masked MSE loss
# --- Encoder: its output is fed straight into the decoder below. ---
encoder_input = layers.Input( input_shape )
# ... define encoder sub-network
encoder = keras.models.Model(encoder_input, ...)
# --- Decoder: takes a tensor of shape `latent_shape` and has two output heads. ---
decoder_input = layers.Input( latent_shape )
# ... define decoder sub-network
x = next_to_last_decoder_layer
# NOTE(review): Dense expects an integer unit count as its first argument; this
# presumably assumes `input_shape` is a flat number of cells -- confirm.
# Head 1: per-cell values, ReLU-activated (so non-negative).
decoder_output_cells = layers.Dense( input_shape, activation='relu')(x)
# Head 2: per-cell mask values in (0, 1) from a sigmoid, thresholded at
# generation time to decide which cells are kept.
decoder_output_masks = layers.Dense( input_shape, activation='sigmoid')(x)
decoder = keras.models.Model(decoder_input, [decoder_output_cells, decoder_output_masks])
# --- End-to-end wiring: input -> encoder -> decoder -> (cells, masks). ---
ae_input = layers.Input( input_shape )
ae_output_cells, ae_output_masks = decoder(encoder(ae_input))
# define loss in terms of the input tensor
def ae_loss(y1, y2):
    """Return the combined masked-MSE and mask cross-entropy training loss.

    The loss is built from the graph tensors ``ae_input``,
    ``ae_output_cells`` and ``ae_output_masks`` captured from the enclosing
    scope, not from the arguments.

    Args:
        y1: Ignored; present only to match the Keras ``loss(y_true, y_pred)``
            call signature.
        y2: Ignored, for the same reason.

    Returns:
        A scalar tensor: the mean of ``masked_mse_loss + alpha * mask_loss``
        over all axes remaining after the last axis has been summed out
        (presumably the batch axis -- confirm against the input layout).
    """
    # Use crossentropy to learn binary masks derived from some energy threshold
    # this could be a constant (e.g. 0) or computed as some fraction of the total energy per-event
    mask_true = K.cast(ae_input > THRESHOLD, 'float32')
    # K.binary_crossentropy takes (target, output): the true mask is the
    # target and the sigmoid mask head is the prediction.
    mask_loss = K.sum(K.binary_crossentropy(mask_true, ae_output_masks), axis=-1)
    n_active_cells = K.sum(mask_true, axis=-1)  # or sum along whatever relevant axes
    # Reconstruction error is measured against the *input* and only on cells
    # that are truly active; EPSILON keeps the division finite when no cell
    # passes the threshold.
    squared_error = K.square(mask_true * (ae_output_cells - ae_input))
    masked_mse_loss = K.sum(squared_error, axis=-1) / (n_active_cells + EPSILON)
    # return a weighted combination of loss terms (alpha is a hyperparameter t.b.d.)
    return K.mean(masked_mse_loss + alpha * mask_loss)
# Assemble the end-to-end model exposing both decoder heads.
autoencoder = keras.models.Model(ae_input, [ae_output_cells, ae_output_masks])
# ae_loss ignores its (y_true, y_pred) arguments and is defined in terms of
# the model's own graph tensors, so attach it once with add_loss() instead of
# compile(loss=...). Passing it to compile() would apply it separately to each
# of the two outputs and would require target arrays, which contradicts the
# target-free fit(y=None) call below.
# NOTE(review): add_loss() with symbolic tensors is the Keras 2.x / tf.keras
# idiom -- confirm against the Keras version in use.
autoencoder.add_loss(ae_loss(None, None))
autoencoder.compile(optimizer=...)
# train AE (no targets: the loss only needs the input tensor)
# Further options (epochs=..., batch_size=..., ...) must be passed as keyword
# arguments; a bare positional `...` after keywords is a SyntaxError.
autoencoder.fit(x=x_train, y=None)
# to generate, sample both the cell and masks values and use a threshold on the mask:
# NOTE(review): the sampled noise presumably has shape (n_samples,) + latent_shape -- confirm.
pred_cells, pred_masks = decoder.predict(np.random.normal(...))
# Zero out every cell whose predicted mask value is not above 0.5.
samples = (pred_masks > 0.5) * pred_cells
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment