@erip
Last active August 5, 2022 12:29
Testing whether an EmbeddingBag's weights can be tied with an Embedding layer's
#!/usr/bin/env python3
import torch
import torch.nn as nn

if __name__ == "__main__":
    V, max_seq, padding_idx, emb_dim, B = 10, 100, 1, 512, 32

    # Build an Embedding layer and an EmbeddingBag initialized from its weights.
    emb_layer = nn.Embedding(V, emb_dim, padding_idx=padding_idx)
    emb_bag = nn.EmbeddingBag.from_pretrained(emb_layer.weight, freeze=False, padding_idx=padding_idx)

    initial_weights = emb_layer.weight.detach()
    assert not initial_weights.requires_grad

    # Random token batch and regression target for a dummy MSE objective.
    tokens = torch.randint(0, V, (B, max_seq))
    y = torch.randn((B, emb_dim))
    loss = nn.MSELoss()

    y_ = emb_bag(tokens)
    l = loss(y_, y)

    assert emb_bag.weight.grad is None
    l.backward()
    assert emb_bag.weight.grad is not None
    # The following assertion fails: from_pretrained wraps the tensor in a new
    # Parameter, so the EmbeddingBag's weight is a distinct autograd leaf and no
    # gradient reaches emb_layer.weight.
    assert emb_layer.weight.grad is not None and torch.allclose(emb_bag.weight.grad, emb_layer.weight.grad)
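
A minimal sketch of one way the weights could actually be tied, assuming the same V, emb_dim, padding_idx, B, and max_seq as above: assign the very same Parameter object to both modules so a single .grad accumulates (the tied_layer/tied_bag names are illustrative, not from the original gist).

# Sketch: tie the weights by sharing the Parameter object itself.
tied_layer = nn.Embedding(V, emb_dim, padding_idx=padding_idx)
tied_bag = nn.EmbeddingBag(V, emb_dim, padding_idx=padding_idx)
tied_bag.weight = tied_layer.weight  # both modules now reference one Parameter

out = tied_bag(torch.randint(0, V, (B, max_seq)))
nn.MSELoss()(out, torch.randn((B, emb_dim))).backward()
# The gradient now lives on the shared Parameter, so both views see it.
assert tied_layer.weight.grad is not None
assert tied_bag.weight.grad is tied_layer.weight.grad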