import torch as T
import torch.nn as nn

from modules import Encoder, Decoder


class RecommendationTransformer(nn.Module):
    """Sequential recommendation model architecture"""

    def __init__(self,
                 vocab_size,
                 heads=4,
                 layers=6,
                 emb_dim=256,
                 pad_id=0,
                 num_pos=128):
        """Recommendation model initializer

        Args:
            vocab_size (int): Number of unique tokens/items
            heads (int, optional): Number of attention heads in each Multi-Head Self-Attention layer. Defaults to 4.
            layers (int, optional): Number of Encoder layers. Defaults to 6.
            emb_dim (int, optional): Embedding dimension. Defaults to 256.
            pad_id (int, optional): Token id used to pad sequences. Defaults to 0.
            num_pos (int, optional): Fixed sequence length for the positional embedding. Defaults to 128.
        """
        super().__init__()

        self.emb_dim = emb_dim
        self.pad_id = pad_id
        self.num_pos = num_pos
        self.vocab_size = vocab_size

        self.encoder = Encoder(source_vocab_size=vocab_size,
                               emb_dim=emb_dim,
                               layers=layers,
                               heads=heads,
                               dim_model=emb_dim,
                               dim_inner=4 * emb_dim,
                               dim_value=emb_dim,
                               dim_key=emb_dim,
                               pad_id=self.pad_id,
                               num_pos=num_pos)

        # Project the encoder output back onto the item vocabulary.
        self.rec = nn.Linear(emb_dim, vocab_size)

    def forward(self, source, source_mask):
        enc_op = self.encoder(source, source_mask)  # (batch, seq_len, emb_dim)
        op = self.rec(enc_op)                       # (batch, seq_len, vocab_size)
        # Permute to (batch, vocab_size, seq_len), the layout nn.CrossEntropyLoss expects.
        return op.permute(0, 2, 1)
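

# A minimal usage sketch (not part of the original gist). It assumes the Encoder
# in modules takes a (batch, seq_len) token tensor plus a (batch, 1, seq_len)
# padding mask and returns (batch, seq_len, emb_dim); adjust the mask shape to
# whatever your Encoder implementation expects.
if __name__ == "__main__":
    model = RecommendationTransformer(vocab_size=1000)
    source = T.randint(1, 1000, (8, 128))                 # 8 sequences of 128 item ids
    source_mask = (source != model.pad_id).unsqueeze(1)   # mask out padded positions
    logits = model(source, source_mask)                   # (batch, vocab_size, seq_len)
    print(logits.shape)                                   # expected: torch.Size([8, 1000, 128])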