Skip to content

Instantly share code, notes, and snippets.

@c0nn3r
Last active November 1, 2018 16:41
Show Gist options
  • Save c0nn3r/5f1c5b3127e9441985799fd0fe21981a to your computer and use it in GitHub Desktop.
Save c0nn3r/5f1c5b3127e9441985799fd0fe21981a to your computer and use it in GitHub Desktop.
import torch
import itertools
def summed_permutations(matrix: torch.Tensor) -> torch.Tensor:
    """Return the row sums of ``matrix`` in every ordering, concatenated.

    Each row of ``matrix`` is summed over ``dim=1``; the resulting sums are
    then emitted once per permutation of the row indices, with permutations
    taken in the order produced by ``itertools.permutations`` and joined
    along dimension 0.

    Args:
        matrix: Tensor with at least two dimensions (``dim=1`` is reduced).

    Returns:
        A tensor holding ``n! * n`` entries along dimension 0, where ``n`` is
        ``matrix.shape[0]``. For a 2-D input this is a flat 1-D tensor. The
        size grows factorially with ``n``, so this is only practical for a
        small number of rows.
    """
    # dim=1 collapses the columns, leaving one total per row.
    row_sums = torch.sum(matrix, dim=1)
    length = int(row_sums.shape[0])

    # Flatten every permutation of the row indices into one index list so a
    # single gather replaces n! separate index_select calls plus a cat.
    flat_order = list(
        itertools.chain.from_iterable(itertools.permutations(range(length)))
    )

    # The index must live on the same device as the data; a plain CPU
    # LongTensor makes index_select fail for inputs on an accelerator.
    index = torch.tensor(flat_order, dtype=torch.long, device=row_sums.device)
    return torch.index_select(row_sums, dim=0, index=index)
# Demo: print the permuted row sums for two small example matrices.
for example in (
    # -> tensor([3, 7, 7, 3])
    [[1, 2], [3, 4]],
    # -> values 6, 12, 21, 6, 21, 12, 12, 6, 21, 12, 21, 6, 21, 6, 12, 21, 12, 6
    [[1, 2, 3], [3, 4, 5], [6, 7, 8]],
):
    print(summed_permutations(torch.tensor(example)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment