nn.Embedding backward grad — a minimal demo of how gradients accumulate in an embedding table. Because the loss is the plain sum of every looked-up vector, the gradient of each weight row ends up equal to the number of times that row's index appears in the input batch.
import torch
import torch.nn as nn

# an Embedding module containing 6 vectors of size 3
embedding = nn.Embedding(6, 3)

# a batch of 2 samples with 3 indices each
indices = torch.LongTensor([[1, 2, 3], [0, 2, 3]])

# summing every looked-up vector makes the gradient of each
# weight row equal the number of times its index was used
loss = torch.sum(embedding(indices))

print("embedding")
print(embedding.weight)
print("indices")
print(indices)
print("forward pass")
print(embedding(indices))
print("loss")
print(loss)

loss.backward()

# rows 2 and 3 receive grad 2.0 in every column (used twice),
# rows 0 and 1 receive 1.0, and the unused rows 4 and 5 stay at 0.0
print("embedding backward grad")
print(embedding.weight.grad)
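
The claim above can be checked numerically: the gradient matrix should match a per-index occurrence count broadcast across the embedding dimension. A minimal sketch (not part of the original gist) using torch.bincount:

import torch
import torch.nn as nn

embedding = nn.Embedding(6, 3)
indices = torch.LongTensor([[1, 2, 3], [0, 2, 3]])

torch.sum(embedding(indices)).backward()

# count how often each of the 6 indices occurs in the batch
counts = torch.bincount(indices.flatten(), minlength=6).float()

# each weight row's gradient should be that row's count, repeated
# across all 3 embedding columns
expected = counts.unsqueeze(1).expand(-1, 3)
assert torch.equal(embedding.weight.grad, expected)
print("gradient matches index counts:", counts.tolist())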
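One related behavior worth knowing: if the embedding is constructed with padding_idx, that row is excluded from gradient accumulation entirely, no matter how often its index appears. A short sketch, again assuming the same toy sizes as above:

import torch
import torch.nn as nn

# with padding_idx, the padding row never receives a gradient
embedding = nn.Embedding(6, 3, padding_idx=0)
indices = torch.LongTensor([[1, 2, 3], [0, 2, 3]])

torch.sum(embedding(indices)).backward()

# row 0 stays zero even though index 0 appears in the batch
print(embedding.weight.grad[0])  # tensor([0., 0., 0.])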