

@quocdat32461997
Last active October 29, 2022 02:54
Trivial example for Mixed-Gradient-Error and Mean-Gradient-Error
```python
import tensorflow as tf


def MeanGradientError(outputs, targets, weight):
    """Weighted mean gradient error between two NHWC image batches."""
    # 3x3 Sobel kernels tiled to [3, 3, C, C] for tf.nn.conv2d, so every output channel
    # sums the Sobel response over all input channels. The x/y labels do not affect the
    # result because only the combined gradient magnitude is used.
    filter_x = tf.tile(tf.expand_dims(tf.constant([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=outputs.dtype), axis=-1), [1, 1, outputs.shape[-1]])
    filter_x = tf.tile(tf.expand_dims(filter_x, axis=-1), [1, 1, 1, outputs.shape[-1]])
    filter_y = tf.tile(tf.expand_dims(tf.constant([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=outputs.dtype), axis=-1), [1, 1, targets.shape[-1]])
    filter_y = tf.tile(tf.expand_dims(filter_y, axis=-1), [1, 1, 1, targets.shape[-1]])
    # output gradients (squared Sobel responses)
    output_gradient_x = tf.math.square(tf.nn.conv2d(outputs, filter_x, strides=1, padding='SAME'))
    output_gradient_y = tf.math.square(tf.nn.conv2d(outputs, filter_y, strides=1, padding='SAME'))
    # target gradients (squared Sobel responses)
    target_gradient_x = tf.math.square(tf.nn.conv2d(targets, filter_x, strides=1, padding='SAME'))
    target_gradient_y = tf.math.square(tf.nn.conv2d(targets, filter_y, strides=1, padding='SAME'))
    # gradient magnitudes: sqrt(Gx^2 + Gy^2)
    output_gradients = tf.math.sqrt(tf.math.add(output_gradient_x, output_gradient_y))
    target_gradients = tf.math.sqrt(tf.math.add(target_gradient_x, target_gradient_y))
    # mean gradient error: squared difference divided by H * W, summed over batch and channels
    shape = output_gradients.shape[1:3]
    mge = tf.math.reduce_sum(tf.math.squared_difference(output_gradients, target_gradients) / (shape[0] * shape[1]))
    return mge * weight


# tf.nn.conv2d expects 4-D [batch, height, width, channels] input, so add a batch axis
x = tf.random.normal(shape=(1, 224, 224, 3))
y = tf.random.normal(shape=(1, 224, 224, 3))
gradient_loss = MeanGradientError(x, y, weight=0.1)
```
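The description mentions Mixed Gradient Error, but the code above only computes the mean gradient error term. Below is a minimal sketch of the mixed loss, assuming the common formulation MixGE = MSE + λ·MGE. It is not the author's code: it swaps in `tf.image.sobel_edges` (per-channel Sobel with reflect padding) instead of the tiled cross-channel filters, averages over all elements instead of summing over batch and channels, and adds a small `eps` inside the square root. The names `mean_gradient_error`, `mixed_gradient_error`, `lambda_g`, and the default values are hypothetical.

```python
import tensorflow as tf


def mean_gradient_error(outputs, targets, eps=1e-6):
    """MGE on per-channel Sobel magnitudes, averaged over all elements (assumed reduction)."""
    def magnitude(images):
        edges = tf.image.sobel_edges(images)  # [batch, H, W, C, 2] holding (dy, dx) per channel
        # eps (assumed value) keeps the derivative of sqrt finite where the gradient is zero
        return tf.math.sqrt(tf.math.reduce_sum(tf.math.square(edges), axis=-1) + eps)

    return tf.math.reduce_mean(tf.math.squared_difference(magnitude(outputs), magnitude(targets)))


def mixed_gradient_error(outputs, targets, lambda_g=0.1):
    """MixGE = pixel-wise MSE + lambda_g * MGE; lambda_g is a hypothetical default weight."""
    mse = tf.math.reduce_mean(tf.math.squared_difference(outputs, targets))
    return mse + lambda_g * mean_gradient_error(outputs, targets)


x = tf.random.normal(shape=(2, 64, 64, 3))
y = tf.random.normal(shape=(2, 64, 64, 3))
loss = mixed_gradient_error(x, y)
```

Because MSE only compares pixel values, the gradient term adds a penalty for blurred or shifted edges; `lambda_g` controls how strongly that penalty counts relative to the pixel loss.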
@HydrogenSulfate

I think there is a small problem in your code at line 4: the Sobel filter should be [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], but your code has [[-1, -2, -2], [0, 0, 0], [1, 2, 1]].
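Why the exact coefficients matter: a valid Sobel kernel sums to zero, so a flat region produces zero gradient. The kernel quoted in the comment sums to -1, so it reports a "gradient" even on a constant image. A quick check as a sketch; the helper `respond` is hypothetical, and `VALID` padding is used only to avoid border effects.

```python
import tensorflow as tf

# Correct Sobel kernel (filter_x in the gist) and the mistyped variant quoted above
sobel = tf.constant([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=tf.float32)
typo = tf.constant([[-1, -2, -2], [0, 0, 0], [1, 2, 1]], dtype=tf.float32)


def respond(kernel, image):
    """Apply one 3x3 kernel to a 2-D image with VALID padding."""
    k = tf.reshape(kernel, [3, 3, 1, 1])          # [height, width, in_channels, out_channels]
    img = image[tf.newaxis, :, :, tf.newaxis]     # [batch, height, width, channels]
    return tf.squeeze(tf.nn.conv2d(img, k, strides=1, padding='VALID'))


flat = tf.fill([5, 5], 3.0)                                            # constant image: no edges
ramp = tf.tile(tf.range(5, dtype=tf.float32)[tf.newaxis, :], [5, 1])  # brightness grows left to right

print(respond(sobel, flat).numpy()[0, 0])                # correct kernel: 0 on the flat image
print(respond(typo, flat).numpy()[0, 0])                 # mistyped kernel: -3 on the flat image
print(respond(tf.transpose(sobel), ramp).numpy()[0, 0])  # transposed kernel (filter_y): 8 on the ramp
```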

@Apprisco

Apprisco commented Aug 2, 2022

Just letting people know: this loss tends to produce NaN losses early in training. It behaves after an epoch or two, but I wouldn't recommend risking it. There should be a better implementation out there somewhere, except that the author's GitHub is no longer visible. This is with the Sobel filter fix, of course.
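One likely cause of these NaNs is the derivative of `tf.math.sqrt` at 0, which is infinite. Wherever an output has zero Sobel response in both directions (for example a flat region), backpropagating through sqrt(Gx² + Gy²) multiplies that infinite factor by 2·G = 0, which yields NaN. A common remedy is to add a small constant inside the square root. The toy sketch below shows the effect on a two-element tensor rather than the full loss; the helper names and `eps = 1e-8` are assumptions, not values tested against the gist.

```python
import tensorflow as tf


def magnitude(gx, gy, eps=0.0):
    """Gradient magnitude in the same form as the gist: sqrt(gx^2 + gy^2 + eps)."""
    return tf.math.sqrt(tf.math.square(gx) + tf.math.square(gy) + eps)


def grad_wrt_gx(eps):
    """d(loss)/d(gx) for a toy squared-difference loss on the magnitude."""
    gx = tf.Variable([0.0, 3.0])  # first element: zero response, as in a flat region
    gy = tf.Variable([0.0, 4.0])
    target = tf.constant([1.0, 1.0])
    with tf.GradientTape() as tape:
        loss = tf.math.reduce_sum(tf.math.squared_difference(magnitude(gx, gy, eps), target))
    return tape.gradient(loss, gx)


print(grad_wrt_gx(0.0).numpy())   # first entry is NaN
print(grad_wrt_gx(1e-8).numpy())  # both entries are finite
```

Applied to the gist, this would mean adding the same small constant inside the two `tf.math.sqrt` calls that compute `output_gradients` and `target_gradients`.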
