Multivariate embedding.
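A small PyTorch module that embeds several categorical variables at once: each variable gets its own nn.Embedding, the per-variable embeddings are optionally scaled by 1/sqrt(embedding_dim), concatenated along the last dimension, linearly projected to a common emb_size, and layer-normalized.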
from math import sqrt
from typing import OrderedDict

import torch
from torch import nn


class MultivariateEmbedding(nn.Module):
    """Embeds several categorical variables, concatenates the results,
    and projects them to a single embedding of size `emb_size`."""
    def __init__(self,
                 embeddings: nn.ModuleDict,
                 emb_size: int,
                 emb_norm: bool = True):
        super().__init__()
        self.embeddings = embeddings
        # Project the concatenated per-variable embeddings down to emb_size.
        self.proj = nn.Linear(sum(e.embedding_dim for e in embeddings.values()),
                              emb_size,
                              bias=False)
        self.emb_norm = emb_norm
        self.norm_layer = nn.LayerNorm(emb_size)

    def forward(self,
                x: OrderedDict[str, torch.Tensor]):
        if set(x.keys()) != set(self.embeddings.keys()):
            raise KeyError('Input keys must be the same as the embeddings.')
        # Optionally scale each embedding by 1/sqrt(dim) so variables with
        # larger embedding dims do not dominate the concatenation.
        embs = torch.cat([
            emb(x[name]) / (sqrt(emb.embedding_dim) if self.emb_norm else 1)
            for name, emb in self.embeddings.items()], dim=-1)
        return self.norm_layer(self.proj(embs))


if __name__ == '__main__':
    embs = nn.ModuleDict([('class_a', nn.Embedding(3, 16)),
                          ('class_b', nn.Embedding(4, 192)),
                          ('class_c', nn.Embedding(5, 512))])
    m = MultivariateEmbedding(embs, 512)

    BATCH_SIZE = 2
    SAMPLE_NUM = 4
    dummy_x = {
        name: torch.randint(emb.num_embeddings, (BATCH_SIZE, SAMPLE_NUM))
        for name, emb in embs.items()}

    print(m(dummy_x).shape)  # torch.Size([2, 4, 512])
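A minimal usage sketch, assuming the class and the demo objects (m, dummy_x) defined above: the embedding layer plugged in as the input stage of a toy per-token classifier. ToyClassifier, its linear head, and n_classes are hypothetical additions for illustration, not part of the original gist.

    # Hedged sketch: MultivariateEmbedding as the input layer of a toy model.
    # ToyClassifier and n_classes are hypothetical, not from the gist.
    class ToyClassifier(nn.Module):
        def __init__(self, embedding: MultivariateEmbedding, emb_size: int, n_classes: int):
            super().__init__()
            self.embedding = embedding                  # the module defined above
            self.head = nn.Linear(emb_size, n_classes)  # hypothetical per-token head

        def forward(self, x):
            # (batch, seq, emb_size) -> (batch, seq, n_classes)
            return self.head(self.embedding(x))

    model = ToyClassifier(m, 512, n_classes=10)
    logits = model(dummy_x)   # reuses the dummy batch from the demo above
    print(logits.shape)       # torch.Size([2, 4, 10])

Because the projection maps every variable combination to a fixed emb_size, the output drops into any downstream sequence model (Transformer, RNN, MLP head) unchanged.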