Skip to content

Instantly share code, notes, and snippets.

@wangkuiyi
Created April 25, 2021 01:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wangkuiyi/dd2e3794d11010f0cd562ed009664f90 to your computer and use it in GitHub Desktop.
Save wangkuiyi/dd2e3794d11010f0cd562ed009664f90 to your computer and use it in GitHub Desktop.
torch.nn.AssocEmb
import torch
class AssocEmb(torch.nn.Module):
    """Associative embedding table that lazily maps string keys to parameters.

    The first time a string key is looked up, a new ``torch.nn.Parameter``
    of shape ``dim`` is created and stored in ``self.tbl``. Every later
    lookup of the same key returns that same parameter object.

    Args:
        dim: Shape of each embedding tensor, passed straight to
            ``torch.rand`` (e.g. ``[2, 2]``).

    Notes:
        * Keys are used as parameter names inside ``self.tbl`` (a
          ``torch.nn.ParameterDict``). Torch rejects names that are empty or
          contain ``"."`` by raising ``KeyError``.
        * Parameters are created lazily in ``forward``. An optimizer built
          from ``self.parameters()`` before a key is first seen will
          therefore not update that key's parameter.
    """

    def __init__(self, dim: list[int]):
        super().__init__()
        self.dim = dim
        self.tbl = torch.nn.ParameterDict()
        # Zero-element, non-persistent buffer. It exists only so that it is
        # carried along by .to(device) / .double() / .half(). Lazily created
        # parameters then use the module's *current* device and dtype instead
        # of always the default. persistent=False keeps it out of state_dict().
        self.register_buffer("_device_anchor", torch.empty(0), persistent=False)

    def forward(self, idx: str) -> torch.Tensor:
        """Return the parameter for key ``idx``, creating it on first use.

        A new entry is initialised with ``torch.rand`` (uniform on [0, 1)),
        on the same device and dtype as the module's anchor buffer.
        """
        if idx not in self.tbl:
            anchor = self._device_anchor
            init = torch.rand(self.dim, device=anchor.device, dtype=anchor.dtype)
            # Parameter defaults to requires_grad=True, so the raw tensor does
            # not need to request gradients itself.
            self.tbl[idx] = torch.nn.Parameter(init)
        return self.tbl[idx]
# Demo: look up two keys, combine them, and backpropagate so that both
# lazily created embeddings receive a gradient.
m = AssocEmb([2, 2])
x = m("apple")
y = m("orange")
z = x * y  # elementwise product of the two [2, 2] embeddings
c = z.sum()
c.backward()
# Since d(sum(x * y))/dx == y and d(sum(x * y))/dy == x, each printed
# gradient equals the other key's embedding.
for name, param in m.tbl.items():
    print(name, param, param.grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment