import pytorch_lightning as pl
import torch
from nvtabular.framework_utils.torch.layers import ConcatenatedEmbeddings, MultiHotEmbeddings

class WideAndDeepMultihot(pl.LightningModule):
    def __init__(
        self,
        model_conf,
        cat_names, cont_names, label_names,
        num_continuous,
        max_output=None,
        bag_mode="sum",
        batch_size=None,
    ):
        super().__init__()
        # configure from model_conf
        embedding_table_shapes = model_conf["embedding_table_shape"]
        # emb_dropout, layer_hidden_dims, layer_dropout_rates also come from
        # model_conf; the original gist truncated these reads for simplicity,
        # so the key names below are assumed to make the example runnable
        emb_dropout = model_conf["emb_dropout"]
        layer_hidden_dims = model_conf["layer_hidden_dims"]
        layer_dropout_rates = model_conf["layer_dropout_rates"]
        mh_shapes = None
        if isinstance(embedding_table_shapes, tuple):
            embedding_table_shapes, mh_shapes = embedding_table_shapes
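        # Illustrative values (not from the original gist): each shapes dict
        # maps a column name to (cardinality, embedding_dim), e.g.
        #   embedding_table_shapes = {"userId": (10000, 64), "movieId": (5000, 32)}
        #   mh_shapes = {"genres": (21, 16)}
        # which is the per-column format NVTabular's get_embedding_sizes produces.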
        if embedding_table_shapes:
            self.initial_cat_layer = ConcatenatedEmbeddings(
                embedding_table_shapes, dropout=emb_dropout
            )
        if mh_shapes:
            self.mh_cat_layer = MultiHotEmbeddings(
                mh_shapes, dropout=emb_dropout, mode=bag_mode
            )
        self.initial_cont_layer = torch.nn.BatchNorm1d(num_continuous)
        embedding_size = sum(emb_size for _, emb_size in embedding_table_shapes.values())
        if mh_shapes is not None:
            embedding_size += sum(emb_size for _, emb_size in mh_shapes.values())
        layer_input_sizes = [embedding_size + num_continuous] + layer_hidden_dims[:-1]
        layer_output_sizes = layer_hidden_dims
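        # the first deep layer consumes the concatenated embeddings plus the
        # continuous features; each subsequent layer consumes the previous
        # hidden dim, so the input sizes are the hidden dims shifted by one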
        self.layers = torch.nn.ModuleList(
            torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.ReLU(inplace=True),
                torch.nn.BatchNorm1d(output_size),
                torch.nn.Dropout(dropout_rate),
            )
            for input_size, output_size, dropout_rate in zip(
                layer_input_sizes, layer_output_sizes, layer_dropout_rates
            )
        )
        # output layer receives wide and deep
        head_input_size = layer_input_sizes[0] + layer_output_sizes[-1]
        self.output_layer = torch.nn.Linear(head_input_size, 1)
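
    # Forward-pass sketch (not in the original gist): a minimal wide-and-deep
    # forward consistent with the sizes computed in __init__. The x_mh input
    # format and the self.mh_cat_layer call are assumptions; NVTabular's
    # dataloader typically delivers multi-hot columns as (values, offsets).
    def forward(self, x_cat, x_cont, x_mh=None):
        parts = []
        if hasattr(self, "initial_cat_layer"):
            parts.append(self.initial_cat_layer(x_cat))
        if x_mh is not None:
            parts.append(self.mh_cat_layer(x_mh))
        parts.append(self.initial_cont_layer(x_cont))
        # "wide" part: raw embeddings plus normalized continuous features
        wide = torch.cat(parts, dim=1)
        deep = wide
        for layer in self.layers:
            deep = layer(deep)
        # the head sees both the wide features and the deep tower's output,
        # matching head_input_size above
        return self.output_layer(torch.cat([wide, deep], dim=1))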
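
# Usage sketch (not part of the original gist): all key names, column names,
# and numbers below are illustrative.
model_conf = {
    "embedding_table_shape": (
        {"userId": (10000, 64), "movieId": (5000, 32)},  # single-hot columns
        {"genres": (21, 16)},                            # multi-hot columns
    ),
    "emb_dropout": 0.1,
    "layer_hidden_dims": [128, 64],
    "layer_dropout_rates": [0.2, 0.2],
}
model = WideAndDeepMultihot(
    model_conf,
    cat_names=["userId", "movieId", "genres"],
    cont_names=["age"],
    label_names=["rating"],
    num_continuous=1,
)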