Last active: April 3, 2024 00:54
-
-
Save leslie-fang-intel/5489e3e55bff82212ba7e70cf4becb8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Iterable
from typing import Union

import torch
def _unsqueeze_multiple(x, dimensions): | |
for dim in sorted(dimensions): | |
x = torch.unsqueeze(x, dim) | |
return x | |
if __name__ == "__main__":
    # Demo: multiply a 4-D tensor by a 1-D tensor of per-slice scales
    # along `axis`.
    x = torch.randn(2, 3, 4, 4)
    scales = torch.tensor([3, 3, 3])
    axis = 1
    print(x.size())
    print(scales.size())

    # Broadcasting aligns trailing dimensions, so `scales` ([3]) is matched
    # against the last axis of `x` (size 4) and the multiply is rejected.
    try:
        _ = x * scales
    except RuntimeError as err:
        print(f"direct broadcast failed: {err}")

    # Give `scales` a size-1 dimension at every position except `axis`, so
    # it broadcasts along that axis instead: [3] -> [1, 3, 1, 1].
    broadcast_dims = [d for d in range(x.ndim) if d != axis]
    u_scales = _unsqueeze_multiple(scales, broadcast_dims)
    print(u_scales.size())

    r = x * u_scales
    print(r.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.