Skip to content

Instantly share code, notes, and snippets.

@YiqinZhao
Created March 1, 2019 07:02
Show Gist options
  • Save YiqinZhao/766305acfb0141b70370e5dcd9415eb6 to your computer and use it in GitHub Desktop.
Save YiqinZhao/766305acfb0141b70370e5dcd9415eb6 to your computer and use it in GitHub Desktop.
Keras metric for multi-class average recall (a.k.a. unweighted accuracy).
class MulticlassAverageRecall(Layer):
    """Stateful Keras metric: multi-class average recall (unweighted accuracy).

    Accumulates per-class counts across calls until ``reset_states`` is
    called:

    * ``t`` is the number of samples whose true class is each class.
    * ``p`` is the number of those samples that were predicted correctly.

    The reported value is the mean of the per-class recalls ``p / t``. Only
    classes that have appeared in ``y_true`` so far are included, so classes
    that have not been seen yet do not turn the result into NaN.

    Args:
        name: Metric name passed to ``Layer.__init__``.
        classes: Number of classes; expected to match the size of the last
            axis of ``y_true`` / ``y_pred``.
        output_idx: Stored on the instance; not used by this class.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, name='multiclass_recall', classes=4,
                 output_idx=0, **kwargs):
        super(MulticlassAverageRecall, self).__init__(name=name, **kwargs)
        self.stateful = True
        self.classes = classes
        self.output_idx = output_idx
        # Running per-class counts: t = true samples, p = correct predictions.
        self.t = K.variable(np.zeros(classes), dtype='int32')
        self.p = K.variable(np.zeros(classes), dtype='int32')

    def reset_states(self):
        """Zero the accumulated per-class counts."""
        K.set_value(self.t, np.zeros(self.classes, dtype='int32'))
        K.set_value(self.p, np.zeros(self.classes, dtype='int32'))

    def __call__(self, y_true, y_pred):
        """Add a batch to the running counts and return the current metric.

        Args:
            y_true: Targets; each sample's class is the argmax of the last
                axis (e.g. one-hot labels).
            y_pred: Predicted scores; each sample's class is the argmax of
                the last axis.

        Returns:
            Scalar tensor of dtype ``K.floatx()``. It is the mean recall over
            the classes with at least one true sample, counting everything
            accumulated so far plus this batch. It is 0 if no class has any
            true samples.
        """
        # One-hot encode the argmax class of targets and predictions, then
        # flatten any leading axes so each row is exactly one sample.
        pred_oh = K.reshape(K.one_hot(K.argmax(y_pred, axis=-1), self.classes),
                            (-1, self.classes))
        true_oh = K.reshape(K.one_hot(K.argmax(y_true, axis=-1), self.classes),
                            (-1, self.classes))

        # Per-class counts for this batch (shape: (classes,)).
        t = K.cast(K.sum(true_oh, axis=0), 'int32')
        p = K.cast(K.sum(true_oh * pred_oh, axis=0), 'int32')

        # Totals including this batch. They are computed from the variables'
        # current values plus the batch counts; the variable updates
        # themselves are registered separately via add_update below.
        total_t = K.cast(self.t + t, K.floatx())
        total_p = K.cast(self.p + p, K.floatx())

        self.add_update(K.update_add(self.t, t))
        self.add_update(K.update_add(self.p, p))

        # A class with no true samples has undefined recall (0/0). Exclude it
        # from the mean instead of letting a NaN poison the whole metric.
        seen = K.cast(K.greater(total_t, 0), K.floatx())
        recall = total_p / K.maximum(total_t, 1)
        return K.sum(recall * seen) / K.maximum(K.sum(seen), 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment