Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
krsnewwave / vae_pytorch.py
Last active May 23, 2022 14:57
VAE recommender implementation on PyTorch Lightning
class MVAERecommender(TopNRecommender):
# TopNRecommender provides the methods for top-k prediction
def __init__(self, model_conf : Dict, novelty_per_item, num_users, num_items, remove_observed = False, ):
# ... configuration is skipped
# # # # Model Structure # # # #
# use a ModuleList so the encoder dimensions can be given as a list
self.encoder = nn.ModuleList()
# enumerate consecutive (d_in, d_out) dimension pairs, with the index starting at 1
for i, (d_in, d_out) in enumerate(zip(self.enc_dims[:-1], self.enc_dims[1:]), start=1):
# double d_out at the last step to hold both the mean and the variance parameters
import pytorch_lightning as pl
from nvtabular.framework_utils.torch.layers import ConcatenatedEmbeddings, MultiHotEmbeddings
import torch
class WideAndDeepMultihot(pl.LightningModule):
def __init__(
self,
model_conf,
cat_names, cont_names, label_names,
num_continuous,
class WideAndDeepMultihot(pl.LightningModule):
# others go here...
#
#
def training_step(self, batch, batch_idx):
# unpack
x_cat, x_cont, y = self.transform.transform_with_label(batch)
# forward
y_pred = self((x_cat, x_cont))
# (1) Create loaders
def create_loaders(train_dataset, valid_dataset):
# dataset and loaders
train_iter = TorchAsyncItr(
train_dataset,
batch_size=BATCH_SIZE,
cats=CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS,
conts=NUMERIC_COLUMNS,
labels=["rating"],
)
import gc
def objective(trial):
### Dataset section
# see https://gist.github.com/krsnewwave/273b9cafa4813771791f076cee32c2e4#file-nvtabular_movielens_main_loop_functions-py-L2
train_loader, valid_loader = create_loaders(train_dataset, valid_dataset)
### Model section
# see https://gist.github.com/krsnewwave/273b9cafa4813771791f076cee32c2e4#file-nvtabular_movielens_main_loop_functions-py-L29
epochs = 1
patience = 3