Skip to content

Instantly share code, notes, and snippets.

@ali-mosavian
Created August 15, 2019 19:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ali-mosavian/230a78d7bf4949e87b86e63a0e82045d to your computer and use it in GitHub Desktop.
Save ali-mosavian/230a78d7bf4949e87b86e63a0e82045d to your computer and use it in GitHub Desktop.
from keras import backend as K
from keras import optimizers
def masked_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy that ignores target entries equal to -1.

    Target entries equal to -1 are treated as padding and zeroed before the
    loss is computed. A target row made entirely of -1 therefore contributes
    zero loss. Rows without any -1 are scored exactly like plain
    ``K.categorical_crossentropy``.

    Args:
        y_true: Target tensor. Presumably one-hot rows, with padded positions
            filled with -1 (assumed layout — confirm against the data
            pipeline).
        y_pred: Predicted class probabilities, same shape as ``y_true``.

    Returns:
        Per-sample loss tensor as returned by ``K.categorical_crossentropy``
        (last axis reduced).
    """
    # 1.0 where the target is real, 0.0 where it is the -1 padding marker.
    mask = K.cast(K.not_equal(y_true, -1), K.floatx())
    # Mask only the target. Masking y_pred too turns a padded row into all
    # zeros, and categorical_crossentropy renormalises y_pred by its sum
    # (0 / 0 -> NaN), which then poisons the whole batch loss. With the
    # target row zeroed, -sum(0 * log(p)) is already 0 for padded rows.
    return K.categorical_crossentropy(y_true * mask, y_pred)
def masked_categorical_accuracy(y_true, y_pred):
    """Categorical accuracy computed only over rows with a real target.

    A row counts as valid when at least one of its target entries is not
    -1. A valid row is correct when ``argmax(y_pred)`` equals
    ``argmax(y_true)`` along the last axis.

    Args:
        y_true: Target tensor. Presumably one-hot rows, with padded positions
            filled with -1 (assumed layout — confirm against the data
            pipeline).
        y_pred: Predicted class scores/probabilities, same shape as
            ``y_true``.

    Returns:
        Scalar tensor: the fraction of valid rows predicted correctly, or 0
        when the batch contains no valid rows.
    """
    # 1.0 where the target is real, 0.0 where it is the -1 padding marker.
    mask = K.cast(K.not_equal(y_true, -1), K.floatx())
    # 1.0 for rows that contain at least one real target entry, else 0.0.
    row_mask = K.max(mask, axis=-1)
    # Per-row hit/miss. Computed with the backend directly, so no separate
    # `metrics` import is needed.
    correct = K.cast(
        K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)),
        K.floatx(),
    )
    # Average over valid rows only. Zeroing padded rows on both sides (the
    # previous approach) made argmax(0...0) == argmax(0...0) and scored every
    # padded row as correct. K.maximum guards an all-padding batch.
    return K.sum(correct * row_mask) / K.maximum(K.sum(row_mask), 1.0)
# Build the model and wrap it for training. build_model and ModelMGPU are
# not defined in this snippet and are assumed to come from the surrounding
# project. ModelMGPU is presumably a multi-GPU wrapper and 2 the GPU
# count — confirm.
model = build_model()
model = ModelMGPU(model, 2)
model.compile(
    loss=masked_categorical_crossentropy,
    # Nadam's first positional argument is the learning rate.
    optimizer=optimizers.Nadam(0.01),
    metrics=[masked_categorical_accuracy],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment