Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Last active April 3, 2024 00:54
Show Gist options
  • Save leslie-fang-intel/5489e3e55bff82212ba7e70cf4becb8d to your computer and use it in GitHub Desktop.
Save leslie-fang-intel/5489e3e55bff82212ba7e70cf4becb8d to your computer and use it in GitHub Desktop.
import torch
def _unsqueeze_multiple(x, dimensions):
    """Insert a size-1 dimension into ``x`` at every index in ``dimensions``.

    Indices refer to positions in the *result*, whose rank is
    ``x.ndim + len(dimensions)``. Negative indices count from the end of that
    result (the same convention as ``numpy.expand_dims``). Non-negative
    indices behave exactly as a plain sorted sequence of ``torch.unsqueeze``
    calls.

    Example: a tensor of shape [3] with dimensions [0, 2, 3] becomes shape
    [1, 3, 1, 1], so it can broadcast against a [N, 3, H, W] tensor.

    Args:
        x: tensor to expand.
        dimensions: iterable of int positions (in the output) for the new
            size-1 dimensions. Order does not matter.

    Returns:
        ``x`` with the extra size-1 dimensions inserted.
    """
    dims = list(dimensions)
    out_ndim = x.ndim + len(dims)
    # Resolve negatives against the final rank. Unsqueezing one at a time
    # would otherwise interpret each negative index relative to a different
    # intermediate rank.
    resolved = sorted(d + out_ndim if d < 0 else d for d in dims)
    # Ascending order: inserting at `dim` only shifts axes at positions >= dim,
    # so dimensions already placed (all <= dim) stay where they were put.
    for dim in resolved:
        x = torch.unsqueeze(x, dim)
    return x
if __name__ == "__main__":
    # Demonstrates scaling a 4-D tensor by a 1-D tensor with one value per
    # entry along `axis`.
    x = torch.randn(2, 3, 4, 4)
    scales = torch.tensor([3, 3, 3])
    axis = 1
    print(x.size())
    print(scales.size())

    # `x * scales` raises a RuntimeError. Broadcasting aligns trailing
    # dimensions, so the length-3 `scales` is matched against x's last
    # dimension (size 4) instead of `axis`.
    # The fix: give `scales` a size-1 dimension everywhere except `axis`.
    broadcast_dims = [d for d in range(x.ndim) if d != axis]
    u_scales = _unsqueeze_multiple(scales, broadcast_dims)  # size: [1, 3, 1, 1]
    print(u_scales.size())

    r = x * u_scales  # now broadcasts along `axis`
    print(r.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment