Last active
May 12, 2021 00:33
-
-
Save lixiaoquan/c59b2c66d86ca7472cdcd960e290b741 to your computer and use it in GitHub Desktop.
Example to understand EmbeddingBag
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn.functional as F

# An embedding "table" containing 10 vectors of size 3.
# Filled with consecutive values so it is easy to see how each result is
# computed by hand: row i holds the values 3*i+1 .. 3*i+3.
embedding_matrix = torch.tensor([[1.0, 2, 3],
                                 [4, 5, 6],
                                 [7, 8, 9],
                                 [10, 11, 12],
                                 [13, 14, 15],
                                 [16, 17, 18],
                                 [19, 20, 21],
                                 [22, 23, 24],
                                 [25, 26, 27],
                                 [28, 29, 30],
                                 ])

# case 0
# index is 1D, offsets is not None.
# len(offsets) == 2 means there are two bags:
# the first bag is index[0:4], the second bag is index[4:8].
# The default mode is 'mean', so each bag's vectors are averaged.
index = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
r = F.embedding_bag(index, embedding_matrix, offsets)
print(r)

# case 1
# Bags can have different sizes: offsets == [0, 1] means the first bag
# (index[0:1]) has size 1 and the 2nd bag (index[1:8]) has size 7.
offsets = torch.tensor([0, 1])
r = F.embedding_bag(index, embedding_matrix, offsets)
print(r)

# case 2
# include_last_offset=True tells the API to use offsets[-1] as the end of the
# last bag instead of len(index).
# See https://github.com/pytorch/pytorch/issues/29019
# Case 2 is equivalent to case 1: the two bags are index[0:1] and index[1:8];
# with include_last_offset, the trailing 8 becomes explicit.
offsets = torch.tensor([0, 1, 8])
r = F.embedding_bag(index, embedding_matrix, offsets, include_last_offset=True)
print(r)

# case 3
# index is 2D, offsets is None.
# index.shape[0] == 2 means there are 2 bags; index.shape[1] == 4 means each
# bag has size 4. With a 2D index every bag must have the same size, so
# case 3 is equivalent to case 0, but a 2D index cannot express case 1
# (whose bags have different sizes).
index = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
r = F.embedding_bag(index, embedding_matrix)
print(r)

# case 4
# Use F.embedding to look at each bag before the sum/mean/max reduction.
index = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
r = F.embedding(index, embedding_matrix)
print(r)
# Reduce within each bag: the first bag holds 4 vectors, which are reduced
# with `mode` ('mean' here, matching embedding_bag's default).
r = torch.mean(r, dim=-2, keepdim=True)
print(r)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment