@ritwikraha
Created March 24, 2022 15:57
The shape_list function from HuggingFace/transformers
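```python
from typing import List, Union

import numpy as np
import tensorflow as tf

def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """
    Deal with dynamic shape in tensorflow cleanly.
    Args:
        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
    Returns:
        `List[int]`: The shape of the tensor as a list.
    """
    # NumPy arrays always have a fully known shape.
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)

    # Runtime shape as a 1-D int32 tensor; valid inside tf.function as well.
    dynamic = tf.shape(tensor)

    # Unknown rank: no static information at all, so the whole 1-D shape
    # tensor is returned (not a list, despite the annotation).
    if tensor.shape == tf.TensorShape(None):
        return dynamic

    # Known rank: use the static size where available (a Python int) and
    # fall back to a scalar tensor from tf.shape where it is None.
    static = tensor.shape.as_list()
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
```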
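A minimal usage sketch, assuming TensorFlow 2.x with eager execution enabled (the default). It exercises the three return cases: a NumPy array, a tensor whose shape is fully static, and tensors traced inside `tf.function` with `None` dimensions or an unknown rank. `merge_batch_and_seq` and `rank_unknown` are hypothetical helpers written only for illustration; they are not part of transformers. The function body is repeated so the snippet runs on its own.

```python
from typing import List, Union

import numpy as np
import tensorflow as tf


def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    # Same logic as the gist's shape_list, repeated so this snippet runs standalone.
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)
    dynamic = tf.shape(tensor)
    if tensor.shape == tf.TensorShape(None):
        return dynamic
    static = tensor.shape.as_list()
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


# Hypothetical helper: (batch, seq, hidden) -> (batch * seq, hidden).
# Batch and sequence length are None at trace time; hidden is statically 4.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 4], dtype=tf.float32)])
def merge_batch_and_seq(x):
    batch, seq, hidden = shape_list(x)  # batch, seq: scalar tensors; hidden: Python int
    return tf.reshape(x, [batch * seq, hidden])


# Hypothetical helper traced with an unknown rank (shape=None).
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def rank_unknown(x):
    return shape_list(x)  # unknown rank: the 1-D tf.shape tensor itself


# NumPy input: a plain list of Python ints.
print(shape_list(np.zeros((2, 3))))  # [2, 3]

# Eager tensor: every dimension is statically known, so all entries are ints.
print(shape_list(tf.zeros((2, 3))))  # [2, 3]

# Traced with unknown batch/seq sizes, yet works for any concrete input.
print(merge_batch_and_seq(tf.zeros((2, 5, 4))).shape)  # (10, 4)

# Unknown rank: the result is a tensor, converted here for printing.
print(rank_unknown(tf.zeros((2, 3))).numpy().tolist())  # [2, 3]
```

Mixing Python ints with scalar tensors is what lets the same code run both eagerly and under `tf.function`: static sizes stay as plain ints, and only the dimensions that are unknown at trace time come from `tf.shape`. The unknown-rank case is the exception, since it returns a tensor rather than a list, so code that unpacks the result into a fixed number of names is assuming the rank is known.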