@gangmul12
Created November 27, 2020 13:03
tf 2.0: force the gradient dL/dy of a tensor to a chosen value
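import tensorflow as tf
import numpy as np

@tf.custom_gradient
def force_grad_layer(x, grad_tensor):
    # Forward pass: identity. Backward pass: discard the incoming dy and
    # return grad_tensor as dL/dx; grad_tensor itself gets a zero gradient.
    def grad(dy):
        return tf.cast(grad_tensor, dy.dtype), tf.zeros_like(grad_tensor)
    return tf.identity(x), grad

# Placeholders below (blahblah, wanted_gradient_value, last_layer,
# func_that_need_forced_gradient_value) must be replaced with your own code.
some_variable = tf.Variable(blahblah)  # variable to differentiate with respect to
with tf.GradientTape() as tape:
    # Must have the same shape as y.
    forced_grad = tf.constant(wanted_gradient_value)
    # y must depend on some_variable, otherwise the gradient below is None.
    y = func_that_need_forced_gradient_value(last_layer)
    yy = force_grad_layer(y, forced_grad)
# Gradient of yy w.r.t. some_variable, computed as if dL/dyy == wanted_gradient_value.
dyy_over_ds_with_dL_over_dyy_equals_wanted_gradient_value = tape.gradient(yy, some_variable)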
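A minimal runnable sketch of the same idea, assuming a hypothetical toy setup that is not in the original: a scalar variable `s`, `y = 3 * s`, and a forced gradient of 5. It checks two things: the forced value replaces whatever gradient arrives from downstream ops (here a hypothetical `loss = yy ** 2`), and the result matches the built-in `output_gradients` argument of `tf.GradientTape.gradient`. All variable names are illustrative.

```python
import tensorflow as tf


@tf.custom_gradient
def force_grad_layer(x, grad_tensor):
    # Forward: identity. Backward: ignore the incoming dy and hand grad_tensor
    # upstream as dL/dx; grad_tensor itself gets a zero gradient.
    def grad(dy):
        return tf.cast(grad_tensor, dy.dtype), tf.zeros_like(grad_tensor)
    return tf.identity(x), grad


# Hypothetical toy setup: y = 3 * s, and we want dL/dy to be 5.
s = tf.Variable(2.0)
forced = tf.constant(5.0)

# persistent=True so we can call tape.gradient more than once.
with tf.GradientTape(persistent=True) as tape:
    y = 3.0 * s
    yy = force_grad_layer(y, forced)
    loss = yy ** 2  # downstream op; its gradient (2 * yy = 12) is discarded by grad()

g_forced = tape.gradient(yy, s)    # 5 * dy/ds = 5 * 3
g_loss = tape.gradient(loss, s)    # still 15: the dy coming from loss is ignored
g_builtin = tape.gradient(y, s, output_gradients=forced)  # same result, no custom op
del tape

print(float(g_forced), float(g_loss), float(g_builtin))  # 15.0 15.0 15.0
```

The built-in `output_gradients` route is simpler when you can call `tape.gradient` on `y` directly; the custom op is handy when `y` sits in the middle of a larger computation and you want to override the gradient arriving from everything downstream of it.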