@davidhughhenrymack
Created July 7, 2018 15:46
import tensorflow as tf


def dynamic_assert_shape(tensor, shape):
    """
    Check that a tensor has a shape given by a list of constants and tensor values.

    This function places an assertion operation into your graph that is executed at runtime.
    This is useful because tensors often have dynamically sized dimensions (for example,
    the batch dimension) whose values are unknown when the graph is built, so they
    cannot be compared or asserted on statically.

    For example, measure a dimension at run time:
        `batch_size = tf.shape(my_tensor)[0]`
    then assert that another tensor does indeed have the expected shape:
        `other_tensor = dynamic_assert_shape(other_tensor, [batch_size, 16])`

    Use this as an inline identity function, as in the example above, so that the
    assertion it creates is added to the graph and executed whenever the returned
    tensor is evaluated.

    Returns: a tensor with the same value as `tensor`, which depends on the assertion.
    """
    lhs = tf.shape(tensor)
    # The expected shape may mix Python ints and scalar tensors; pack it into one int tensor.
    rhs = tf.convert_to_tensor(shape, dtype=lhs.dtype)
    # Fails at runtime if the actual shape differs from the expected one.
    assert_op = tf.assert_equal(lhs, rhs, message=f"Asserting shape of {tensor.name}")
    # The identity op is forced to run after the assertion, so evaluating the
    # returned tensor triggers the check.
    with tf.control_dependencies([assert_op]):
        return tf.identity(tensor, name="dynamic_assert_shape")
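The function above is specific to TensorFlow graphs. As a framework-free illustration of the same idea (compare an actual shape against a spec that mixes constants and sizes measured from other data, then return the input unchanged so the check can be used inline), here is a minimal pure-Python sketch. `nested_shape` and `assert_shape` are hypothetical helpers, not part of TensorFlow, and unlike `dynamic_assert_shape` the check runs immediately when called rather than as a deferred graph op.

```python
from typing import Any, Sequence, Tuple


def nested_shape(value: Any) -> Tuple[int, ...]:
    """Hypothetical helper: infer the shape of a rectangular nested list."""
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        # Descend into the first element; an empty list ends the walk.
        value = value[0] if value else None
    return tuple(shape)


def assert_shape(value: Any, expected: Sequence[int], name: str = "value") -> Any:
    """Hypothetical analogue of dynamic_assert_shape: check, then pass through."""
    actual = nested_shape(value)
    if tuple(expected) != actual:
        raise ValueError(
            f"Asserting shape of {name}: expected {tuple(expected)}, got {actual}"
        )
    # Return the argument unchanged so the call can be used inline.
    return value


# Usage: measure a dimension from one value, then check another value against it.
my_batch = [[0.0] * 16 for _ in range(4)]
batch_size = nested_shape(my_batch)[0]
other = [[1.0] * 16 for _ in range(4)]
other = assert_shape(other, [batch_size, 16], name="other")
```

Returning the input makes the check composable in the same way as the TensorFlow version: it can wrap any expression without introducing a separate statement.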