Mauro Belgiovine mauro-belgiovine

mauro-belgiovine / tf_cond_example_gradients.py
Last active April 18, 2024 17:17
Example of tf.cond usage to allow gradient propagation through conditional branches
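As a framework-free illustration of the same idea, the sketch below estimates gradients of a two-branch function by central finite differences. It is not part of the original gist and does not use TensorFlow. The names `branch_value` and `numeric_grads` are hypothetical, and the false branch (`y * y`) is an assumed stand-in chosen only for illustration. Only the variables used by the executed branch receive a non-zero gradient; for such unconnected variables `tf.GradientTape.gradient` returns `None` by default rather than zero.

```python
def branch_value(x, y, z):
    # True branch mirrors tf.add(x, z) from the gist.
    # False branch (y * y) is a hypothetical stand-in, assumed for illustration only.
    return x + z if x < y else y * y


def numeric_grads(f, args, eps=1e-6):
    """Central finite-difference estimate of df/d(arg) for each argument."""
    grads = []
    for i in range(len(args)):
        hi = list(args)
        lo = list(args)
        hi[i] += eps
        lo[i] -= eps
        grads.append((f(*hi) - f(*lo)) / (2 * eps))
    return grads


# Predicate x < y is true: only x and z influence the output.
grads_true = numeric_grads(branch_value, [5.0, 6.0, 8.0])

# Predicate x < y is false: only y influences the output (assumed branch).
grads_false = numeric_grads(branch_value, [7.0, 6.0, 8.0])

print(grads_true)
print(grads_false)
```

The predicate itself contributes nothing to the gradient: `y` appears in `x < y`, yet its estimate is zero whenever the true branch runs, because a comparison is piecewise constant.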
import tensorflow as tf
# X, Y and Z are all watched variables, but each one yields a gradient (i.e. not None) only if the executed branch uses it
with tf.GradientTape() as tape:
    x = tf.Variable(5.0, dtype=tf.float32, name='X')
    y = tf.Variable(6.0, dtype=tf.float32, name='Y')
    z = tf.Variable(8.0, dtype=tf.float32, name='Z')
    cond = tf.cond(pred = x < y,
                   true_fn = lambda: tf.add(x, z),