Last active
July 24, 2024 08:54
-
-
Save lzqlzzq/df638619c4fbf857d267bc2ecf5922fe to your computer and use it in GitHub Desktop.
Multivariate embedding.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import sqrt | |
from typing import OrderedDict | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class MultivariateEmbedding(nn.Module):
    """Jointly embed several named categorical variables into one vector.

    Each variable is looked up in its own embedding table. The lookups are
    concatenated along the last dimension, projected by a bias-free linear
    layer to ``emb_size`` and then passed through ``LayerNorm``.

    Args:
        embeddings: Mapping from variable name to an embedding module exposing
            ``embedding_dim`` (e.g. ``nn.Embedding``). Its iteration order
            fixes the concatenation order fed to the projection.
        emb_size: Size of the output feature dimension.
        emb_norm: If True, each variable's embedding is divided by
            ``sqrt(embedding_dim)`` before concatenation, so tables of
            different widths are rescaled individually.

    Raises:
        ValueError: If ``embeddings`` is empty (there would be nothing to
            concatenate in ``forward``).
    """

    def __init__(self,
                 embeddings: nn.ModuleDict,
                 emb_size: int,
                 emb_norm: bool = True):
        super().__init__()
        if len(embeddings) == 0:
            # Fail early: torch.cat on an empty list would raise an obscure
            # error only at the first forward call.
            raise ValueError('At least one embedding is required.')
        self.embeddings = embeddings
        # Input width is the total width of all concatenated embeddings.
        self.proj = nn.Linear(sum(e.embedding_dim for e in embeddings.values()),
                              emb_size,
                              bias=False)
        self.emb_norm = emb_norm
        self.norm_layer = nn.LayerNorm(emb_size)

    def forward(self, x: OrderedDict[str, torch.Tensor]) -> torch.Tensor:
        """Embed, concatenate, project and normalize the input variables.

        Args:
            x: Mapping (any dict works, ordering is irrelevant) from variable
                name to an index tensor for that variable's embedding table.
                The key set must equal the key set of ``self.embeddings``.
                All index tensors must share the same shape, because their
                embeddings are concatenated along the last dimension.

        Returns:
            Tensor of shape ``(*index_shape, emb_size)``.

        Raises:
            KeyError: If the keys of ``x`` differ from the embedding names;
                the message lists the missing and unexpected keys.
        """
        if set(x.keys()) != set(self.embeddings.keys()):
            missing = [k for k in self.embeddings.keys() if k not in x]
            unexpected = [k for k in x.keys() if k not in self.embeddings]
            raise KeyError('Input keys must be the same as the embeddings. '
                           f'Missing: {missing}; unexpected: {unexpected}.')
        parts = []
        for name, emb in self.embeddings.items():
            part = emb(x[name])
            if self.emb_norm:
                part = part / sqrt(emb.embedding_dim)
            parts.append(part)
        return self.norm_layer(self.proj(torch.cat(parts, dim=-1)))
if __name__ == '__main__':
    # Smoke run: three categorical variables with different vocabulary sizes
    # and embedding widths, jointly projected to a 512-wide output.
    BATCH_SIZE = 2
    SAMPLE_NUM = 4
    embs = nn.ModuleDict([
        ('class_a', nn.Embedding(3, 16)),
        ('class_b', nn.Embedding(4, 192)),
        ('class_c', nn.Embedding(5, 512)),
    ])
    m = MultivariateEmbedding(embs, 512)
    # One random index tensor per variable, each within its own vocabulary.
    dummy_x = {}
    for var_name, table in embs.items():
        dummy_x[var_name] = torch.randint(table.num_embeddings,
                                          (BATCH_SIZE, SAMPLE_NUM))
    print(m(dummy_x).shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment