Mauro Belgiovine (mauro-belgiovine)

Last active April 18, 2024 17:17
Example of tf.cond usage to allow gradient propagation through conditional branches
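The following is a minimal, self-contained sketch of the behavior described in the gist, assuming TensorFlow 2.x running eagerly. The helper `branch_gradients` and the false branch `tf.multiply(y, z)` are hypothetical, added only for illustration; they are not taken from the gist. The sketch shows that `tape.gradient` returns `None` for any variable that the executed branch does not use. Inside a `tf.function` both branches are traced, so results there may differ; this sketch only covers eager execution.

```python
import tensorflow as tf

# Trainable variables mirroring the gist's X, Y and Z.
x = tf.Variable(5.0, dtype=tf.float32, name='X')
y = tf.Variable(6.0, dtype=tf.float32, name='Y')
z = tf.Variable(8.0, dtype=tf.float32, name='Z')


def branch_gradients(x_value):
    """Hypothetical helper: set X, run tf.cond under a tape, return (out, dX, dY, dZ).

    The false branch (y * z) is an illustrative assumption, not from the gist.
    """
    x.assign(x_value)
    # Trainable variables read inside the tape context are watched automatically.
    with tf.GradientTape() as tape:
        out = tf.cond(pred=x < y,
                      true_fn=lambda: tf.add(x, z),
                      false_fn=lambda: tf.multiply(y, z))
    dx, dy, dz = tape.gradient(out, [x, y, z])
    return out, dx, dy, dz


# X = 5 < Y = 6: the true branch runs (out = X + Z), so Y gets no gradient.
out, dx, dy, dz = branch_gradients(5.0)
print(out.numpy(), dx.numpy(), dy, dz.numpy())  # 13.0 1.0 None 1.0

# X = 7 >= Y = 6: the false branch runs (out = Y * Z), so X gets no gradient.
out, dx, dy, dz = branch_gradients(7.0)
print(out.numpy(), dx, dy.numpy(), dz.numpy())  # 48.0 None 8.0 6.0
```

If zeros are preferred over `None` for unused variables, `tape.gradient` also accepts `unconnected_gradients=tf.UnconnectedGradients.ZERO`.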
import tensorflow as tf

# X, Y and Z are all watched variables, but each one produces a gradient (i.e. not None)
# only if the executed branch uses that variable.
with tf.GradientTape() as tape:
    # Trainable variables accessed inside the tape context are watched automatically.
    x = tf.Variable(5.0, dtype=tf.float32, name='X')
    y = tf.Variable(6.0, dtype=tf.float32, name='Y')
    z = tf.Variable(8.0, dtype=tf.float32, name='Z')
    # In eager mode, tf.cond runs only the branch selected by pred.
    cond = tf.cond(pred = x < y,
                   true_fn = lambda: tf.add(x, z),