@refraction-ray
Created July 11, 2020 07:11
TensorFlow GradientTape with a ragged tensor
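import tensorflow as tf

# Note: tape.watch() cannot accept a ragged tensor as input,
# so watch the flat values tensor (rt.values) instead.
rt = tf.ragged.constant([[1.], [2., 3.]])
# rt = rt.to_tensor()
with tf.GradientTape() as t:
    t.watch(rt.values)
    loss = 2 * rt[0, 0] + rt[1, 1] ** 3
# Gradient with respect to the flat values (one entry per element of rt).
g_flatten = t.gradient(loss, rt.values)
# Reattach the row structure of rt to the flat gradient.
g = rt.with_values(g_flatten)
print(g)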
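
As a sanity check on the printed result: for loss = 2*rt[0,0] + rt[1,1]**3 the analytic gradient is 2 for rt[0,0], 0 for rt[1,0], and 3*3**2 = 27 for rt[1,1], so g should have rows [2] and [0, 27]. The sketch below reproduces those numbers with central finite differences on plain nested lists, without TensorFlow. The helper names `loss_fn` and `numerical_grad` are hypothetical and not part of the gist.

```python
# Plain-Python check of the gradient values (no TensorFlow involved).
# loss_fn and numerical_grad are illustrative helpers, not from the gist.

def loss_fn(rows):
    # Mirrors loss = 2*rt[0,0] + rt[1,1]**3 on a list of rows.
    return 2 * rows[0][0] + rows[1][1] ** 3

def numerical_grad(f, rows, eps=1e-4):
    # Central finite differences, keeping the ragged row structure.
    grad = []
    for i, row in enumerate(rows):
        grad_row = []
        for j in range(len(row)):
            plus = [list(r) for r in rows]
            minus = [list(r) for r in rows]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad_row.append((f(plus) - f(minus)) / (2 * eps))
        grad.append(grad_row)
    return grad

rows = [[1.0], [2.0, 3.0]]
grad = numerical_grad(loss_fn, rows)
print(grad)  # approximately [[2.0], [0.0, 27.0]]
```

The result keeps the same row lengths as the input, which is what rt.with_values(g_flatten) restores in the TensorFlow version.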