@goddoe
Created July 30, 2019 09:16
repeat vs expand in PyTorch
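import torch

A = torch.randn([12, 9, 64])
B = torch.randn([12, 9, 64])

# repeat copies data: Ar[b, i*9 + j] = A[b, i] and Br[b, i*9 + j] = B[b, j]
Ar = A.repeat(1, 1, 9).view(12, 81, 64)
Br = B.repeat(1, 9, 1)
C = torch.cat((Ar, Br), dim=2)  # shape [12, 81, 128]

# expand builds broadcast views over a new pair axis; cat then materializes them
D = torch.cat([A.unsqueeze(2).expand(-1, -1, 9, -1),   # [12, 9, 9, 64], [b, i, j] -> A[b, i]
               B.unsqueeze(1).expand(-1, 9, -1, -1)],  # [12, 9, 9, 64], [b, i, j] -> B[b, j]
              dim=-1).view(12, 81, 128)

print((C - D).abs().max().item())  # should be 0.0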
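A minimal sketch of the same idea in reusable form, using only public `torch` calls. The helper name `pairwise_concat` is hypothetical (it is not part of the original gist or of PyTorch), and it assumes 3-D inputs of shape `[batch, n, d1]` and `[batch, m, d2]`. The second half checks the memory behavior that separates the two approaches: `repeat` allocates a new tensor, while `expand` returns a view over the original storage with stride 0 on the broadcast axis.

```python
import torch


def pairwise_concat(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for A [batch, n, d1] and B [batch, m, d2], return
    [batch, n*m, d1 + d2] where row i*m + j is cat(A[:, i], B[:, j])."""
    batch, n, d1 = A.shape
    _, m, d2 = B.shape
    # expand creates stride-0 views along the new axis; nothing is copied yet
    A_e = A.unsqueeze(2).expand(batch, n, m, d1)
    B_e = B.unsqueeze(1).expand(batch, n, m, d2)
    # torch.cat allocates a fresh contiguous tensor, so view() is safe here
    return torch.cat([A_e, B_e], dim=-1).view(batch, n * m, d1 + d2)


A = torch.randn(12, 9, 64)
B = torch.randn(12, 9, 64)

# repeat allocates new memory: the result does not share A's storage
A_rep = A.repeat(1, 9, 1)
print(A_rep.data_ptr() == A.data_ptr())  # False

# expand returns a view: same storage, stride 0 on the expanded axis
A_exp = A.unsqueeze(2).expand(-1, -1, 9, -1)
print(A_exp.data_ptr() == A.data_ptr())  # True
print(A_exp.stride())                    # (576, 64, 0, 1)

# the helper reproduces the repeat-based result exactly
C = torch.cat((A.repeat(1, 1, 9).view(12, 81, 64), B.repeat(1, 9, 1)), dim=2)
print(torch.equal(C, pairwise_concat(A, B)))  # True
```

Because several elements of an expanded view point at the same memory location, writing into `A_exp` in place is unsafe; the `torch.cat` step is what turns the views into an independent tensor.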