Skip to content

Instantly share code, notes, and snippets.

@lukasgabriel
Last active December 1, 2022 18:14
Show Gist options
  • Save lukasgabriel/dc19d203401d06a1fa3c763e839defe0 to your computer and use it in GitHub Desktop.
Save lukasgabriel/dc19d203401d06a1fa3c763e839defe0 to your computer and use it in GitHub Desktop.
katy
def correlation_coefficient(y_true, y_pred):
    """Return the Pearson correlation coefficient (CC) of two maps.

    This function is for the single-duration model.

    Both inputs are reduced over axes 1 and 2, so they are assumed to be
    laid out as ``(batch, rows, cols)`` or ``(batch, rows, cols, channels)``
    -- TODO confirm against the caller.

    Args:
        y_true: Ground-truth map tensor.
        y_pred: Predicted map tensor, same shape as ``y_true``.

    Returns:
        A tensor with axes 1 and 2 reduced away, holding one correlation
        value in ``[-1, 1]`` per remaining index. Degenerate (constant)
        maps yield ``0`` instead of NaN.
    """
    eps = K.epsilon()

    def _spatial(reduce_fn, tensor, keepdims):
        # Apply ``reduce_fn`` over axes 2 and then 1. With ``keepdims=True``
        # the result broadcasts against the input, which replaces the
        # explicit expand_dims/repeat_elements tiling.
        inner = reduce_fn(tensor, axis=2, keepdims=keepdims)
        return reduce_fn(inner, axis=1, keepdims=keepdims)

    # Scale the prediction by its peak, then turn both maps into
    # distributions that sum to one. ``eps`` guards against all-zero maps.
    y_pred = y_pred / (_spatial(K.max, y_pred, True) + eps)
    y_true = y_true / (_spatial(K.sum, y_true, True) + eps)
    y_pred = y_pred / (_spatial(K.sum, y_pred, True) + eps)

    # Pearson's r in centered form: cov(x, y) / (std(x) * std(y)).
    # The previous expression mixed sums and means (only one term was
    # divided by N), so the value under the square root was non-positive
    # for non-negative maps and the result was NaN.
    true_centered = y_true - _spatial(K.mean, y_true, True)
    pred_centered = y_pred - _spatial(K.mean, y_pred, True)

    covariance = _spatial(K.sum, true_centered * pred_centered, False)
    true_energy = _spatial(K.sum, K.square(true_centered), False)
    pred_energy = _spatial(K.sum, K.square(pred_centered), False)

    # The 1/N factors of covariance and variances cancel, so plain sums
    # suffice. Clamp instead of adding eps so that well-conditioned
    # denominators are left unbiased.
    denominator = K.sqrt(true_energy * pred_energy)
    return covariance / K.maximum(denominator, eps)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment