@lixiaoquan
Last active May 12, 2021 00:33
Example to understand EmbeddingBag
import torch
import torch.nn.functional as F
# an embedding matrix with 10 rows of size 3,
# filled with values that make it easy to see how the calculation is done
embedding_matrix = torch.tensor([[ 1.0,  2,  3],
                                 [ 4,  5,  6],
                                 [ 7,  8,  9],
                                 [10, 11, 12],
                                 [13, 14, 15],
                                 [16, 17, 18],
                                 [19, 20, 21],
                                 [22, 23, 24],
                                 [25, 26, 27],
                                 [28, 29, 30],
                                 ])
# case 0
# index is 1D and offsets is not None
# len(offsets) == 2 means there are two bags:
# the first bag is index[0:4], the second bag is index[4:8]
index = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
r = F.embedding_bag(index, embedding_matrix, offsets)
print(r)
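To make the reduction concrete, here is a pure-Python check of case 0 (no torch needed). `F.embedding_bag` defaults to `mode='mean'`, so each bag is the element-wise mean of the rows it gathers; the `rows` list below mirrors `embedding_matrix`.

```python
# Pure-Python sketch of case 0: gather each bag's rows, then average column-wise.
rows = [[3 * i + 1.0, 3 * i + 2.0, 3 * i + 3.0] for i in range(10)]  # same values as embedding_matrix
index = [1, 2, 4, 5, 4, 3, 2, 9]
offsets = [0, 4]

bags = []
for b, start in enumerate(offsets):
    # each bag ends where the next one starts; the last bag runs to the end of index
    end = offsets[b + 1] if b + 1 < len(offsets) else len(index)
    members = [rows[i] for i in index[start:end]]
    bags.append([sum(col) / len(members) for col in zip(*members)])

print(bags)  # [[10.0, 11.0, 12.0], [14.5, 15.5, 16.5]]
```

The printed values match the tensor `F.embedding_bag` returns above.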
# case 1
# Bags can have different sizes: offsets = [0, 1] means the first bag is
# index[0:1] (size 1) and the second bag is index[1:8] (size 7)
offsets = torch.tensor([0, 1])
r = F.embedding_bag(index, embedding_matrix, offsets)
print(r)
# case 2
# include_last_offset tells the API to use offsets[-1] as the end of the last bag
# see https://github.com/pytorch/pytorch/issues/29019
# case 2 is equivalent to case 1: the two bags are index[0:1] and index[1:8];
# with include_last_offset, the end index 8 becomes explicit
offsets = torch.tensor([0, 1, 8])
r = F.embedding_bag(index, embedding_matrix, offsets, include_last_offset=True)
print(r)
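The offset bookkeeping in cases 1 and 2 can be sketched with a small helper (a hypothetical name, not a torch API) that turns `offsets` into `(start, end)` bag slices, with and without `include_last_offset`:

```python
# Sketch of how offsets map to bag boundaries in F.embedding_bag.
def bag_slices(index_len, offsets, include_last_offset=False):
    if include_last_offset:
        bounds = offsets                 # offsets[-1] closes the last bag
    else:
        bounds = offsets + [index_len]   # last bag runs to the end of index
    return [(bounds[i], bounds[i + 1]) for i in range(len(bounds) - 1)]

print(bag_slices(8, [0, 1]))                               # [(0, 1), (1, 8)]
print(bag_slices(8, [0, 1, 8], include_last_offset=True))  # [(0, 1), (1, 8)]
```

Both calls produce the same slices, which is why case 1 and case 2 print the same result.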
# case 3
# index is 2D and offsets is None
# index.shape[0] == 2 means there are 2 bags, index.shape[1] == 4 means each bag has size 4
# Every bag must have the same size; case 3 is equivalent to case 0,
# but a 2D index cannot represent the variable-size bags of case 1
index = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
r = F.embedding_bag(index, embedding_matrix)
print(r)
# case 4
# Use F.embedding to inspect each bag before the sum/mean/max reduction
index = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
r = F.embedding(index, embedding_matrix)
print(r)
# reduce within each bag: the first bag holds 4 vectors, which `mode` reduces to one
# (keepdim=True keeps a singleton bag dimension, so the shape is (2, 1, 3) rather than (2, 3))
r = torch.mean(r, dim=-2, keepdim=True)
print(r)
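A self-contained check (assuming the same row values as `embedding_matrix`, rebuilt here with `torch.arange`) that case 3 and case 4 agree: a 2D-index `F.embedding_bag` with the default `mode='mean'` should equal a plain `F.embedding` lookup followed by a mean over the bag dimension.

```python
import torch
import torch.nn.functional as F

# Rebuild the 10x3 matrix with values 1..30, matching embedding_matrix above.
weight = torch.arange(1.0, 31.0).reshape(10, 3)
idx2d = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])

r_bag = F.embedding_bag(idx2d, weight)                    # 2D index, default mode='mean'
r_manual = torch.mean(F.embedding(idx2d, weight), dim=-2)  # lookup, then mean per bag

print(torch.allclose(r_bag, r_manual))  # True
```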