@alkalait
Created May 17, 2021 15:08
Are the shapes of two tensors compatible for broadcasting?
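from torch import Tensor


def is_broadcastable(x: Tensor, y: Tensor) -> bool:
    """Are the shapes of two tensors compatible for broadcasting?"""
    # Compare dimensions right to left (trailing dims first). zip() stops at the
    # shorter shape, so missing leading dims behave like size 1 and always match.
    for a, b in zip(reversed(x.shape), reversed(y.shape)):
        # Each aligned pair must be equal, or one of the two sizes must be 1.
        if a != b and a != 1 and b != 1:
            return False
    return True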
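The same rule can also be checked on plain shape tuples, without importing PyTorch, and extended to report the resulting shape. The sketch below is a hypothetical helper, `broadcast_shape` (not part of the original gist or of any library), assuming NumPy/PyTorch broadcasting semantics: shapes are right-aligned, missing leading dimensions count as 1, and each aligned pair must be equal or contain a 1. It returns the broadcast shape, or `None` when the shapes are incompatible. PyTorch 1.8 and later also provide `torch.broadcast_shapes`, which raises an error instead of returning `None`.

```python
from itertools import zip_longest
from typing import Optional, Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return the broadcast shape of `a` and `b`, or None if incompatible.

    Hypothetical helper: works on plain shape tuples, no tensors needed.
    """
    out = []
    # Align shapes on the right; pad the shorter one with 1s on the left.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            return None  # sizes differ and neither is 1
    return tuple(reversed(out))


print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
print(broadcast_shape((3, 4), (4,)))             # (3, 4)
print(broadcast_shape((3, 4), (3,)))             # None
```

Two shapes are broadcastable exactly when `broadcast_shape` does not return `None`, so this gives the same answer as `is_broadcastable(x, y)` applied to `x.shape` and `y.shape`.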