@mauro-belgiovine
Last active April 18, 2024 17:17
Example of tf.cond usage to allow gradient propagation through conditional branches
import tensorflow as tf

# X, Y and Z are all watched variables, but a variable only produces a
# gradient (i.e. not None) if the executed branch actually uses it.
with tf.GradientTape() as tape:
    x = tf.Variable(5.0, dtype=tf.float32, name='X')
    y = tf.Variable(6.0, dtype=tf.float32, name='Y')
    z = tf.Variable(8.0, dtype=tf.float32, name='Z')
    cond = tf.cond(pred=x < y,
                   true_fn=lambda: tf.add(x, z),
                   false_fn=lambda: tf.square(y))

print(tape.watched_variables())
op = tape.gradient(cond, tape.watched_variables())
print(op)
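Here x < y is true, so true_fn runs and the result depends on X and Z: the gradient call should return 1.0 for both X and Z, and None for Y, since Y is read only by the predicate, which does not contribute to the gradient.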
# In this case, only X and Y are watched, and X will have gradient=None
# because it appears only in the (non-differentiable) predicate.
with tf.GradientTape() as tape:
    x = tf.Variable(5.0, dtype=tf.float32, name='X')
    y = tf.Variable(6.0, dtype=tf.float32, name='Y')
    z = tf.Variable(8.0, dtype=tf.float32, name='Z')
    cond = tf.cond(pred=x > y,
                   true_fn=lambda: tf.add(y, y),
                   false_fn=lambda: tf.square(y))

print(tape.watched_variables())
op = tape.gradient(cond, tape.watched_variables())
print(op)
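For comparison, here is a minimal sketch (not part of the original gist) using tf.where instead of tf.cond. In eager execution tf.cond runs only the selected branch, while tf.where computes both branches and selects between them, so a variable used only in the unselected branch should receive a zero gradient rather than None:

import tensorflow as tf

x = tf.Variable(5.0, dtype=tf.float32, name='X')
y = tf.Variable(6.0, dtype=tf.float32, name='Y')
z = tf.Variable(8.0, dtype=tf.float32, name='Z')

with tf.GradientTape() as tape:
    # both x + z and tf.square(y) are computed; tf.where selects the result
    out = tf.where(x < y, x + z, tf.square(y))

# expect gradients of 1.0 for X and Z, and 0.0 (not None) for Y,
# since the unselected branch gets zeros in the backward pass
print(tape.gradient(out, [x, y, z]))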