@RaphaelMeudec
Last active July 18, 2019 14:56
How to register a custom gradient with TF2 (graph mode)
import tensorflow as tf


@tf.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    gate_f = tf.cast(op.outputs[0] > 0, "float32")  # Filter must be activated
    gate_R = tf.cast(grad > 0, "float32")  # Grads must be positive
    return gate_f * gate_R * grad


# gradient_override_map only works in graph mode, and it only affects ops
# created inside its scope, so the model must be built within it.
with tf.Graph().as_default() as g:
    with g.gradient_override_map({"Relu": "GuidedRelu"}):
        model = tf.keras.applications.resnet50.ResNet50(weights='imagenet', include_top=True)
        # Do stuff here
        pass
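The registered `GuidedRelu` gradient implements the guided-backpropagation rule. It passes the upstream gradient through only where the ReLU output was positive (`gate_f`) and the gradient itself is positive (`gate_R`). A standard ReLU applies only the first gate.

The sketch below is framework-free and shows the difference element by element. It makes one assumption for illustration: it uses plain Python lists instead of tensors. The names `guided_relu_grad` and `relu_grad` are hypothetical helpers, not TensorFlow APIs.

```python
def relu_grad(forward_out, upstream_grad):
    """Standard ReLU backward pass: only the forward-activation gate (gate_f)."""
    return [g if y > 0 else 0.0 for y, g in zip(forward_out, upstream_grad)]


def guided_relu_grad(forward_out, upstream_grad):
    """Guided-backprop ReLU backward pass (hypothetical helper).

    Mirrors gate_f * gate_R * grad: keep the gradient only where the ReLU
    output was positive AND the incoming gradient is positive.
    """
    return [
        g if (y > 0 and g > 0) else 0.0
        for y, g in zip(forward_out, upstream_grad)
    ]


relu_out = [0.0, 2.0, 3.0, 0.5]   # ReLU outputs for inputs [-1, 2, 3, 0.5]
upstream = [5.0, -1.0, 4.0, 0.25]  # gradients arriving from the next layer

print(relu_grad(relu_out, upstream))         # [0.0, -1.0, 4.0, 0.25]
print(guided_relu_grad(relu_out, upstream))  # [0.0, 0.0, 4.0, 0.25]
```

The only difference is the second entry. The plain ReLU lets the negative gradient `-1.0` through, while the guided version zeroes it.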