#!/usr/bin/env python3
import torch
import torch.nn as nn
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.modules.embedding_configs import pooling_type_to_str, DataType
from torchrec.modules.embedding_modules import get_embedding_names_by_table
from torchrec.modules.mlp import MLP
from torchrec.models.dlrm import DenseArch, InteractionArch, InteractionDCNArch, InteractionProjectionArch, OverArch
from typing import Dict, List, Optional
CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE = [
40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, 3067956,
405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 40000000, 40000000,
590152, 12973, 108, 36
]
CRITEO_SYNTH_MULTIHOT_SIZES = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1]
CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM = [sum(CRITEO_SYNTH_MULTIHOT_SIZES[:i]) for i in range(len(CRITEO_SYNTH_MULTIHOT_SIZES))]
zero = torch.zeros(1, dtype=torch.int32)
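
# The 26 Criteo sparse features contribute sum(CRITEO_SYNTH_MULTIHOT_SIZES) == 214
# multi-hot index columns per sample. CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM holds
# the first column of each feature in that flattened layout (feature 0 owns columns
# 0:3, feature 1 owns columns 3:5, and so on); `zero` is prepended to each feature's
# per-sample offsets in EmbeddingBagCollection.forward below.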


class EmbeddingBagCollection(nn.Module):
    def __init__(
        self,
        tables,
        device=None,
    ):
        super().__init__()
        torch._C._log_api_usage_once(
            f"torchrec.modules.{self.__class__.__name__}")
        self.embedding_bags = nn.ModuleDict()
        self._embedding_bag_configs = tables
        self._lengths_per_embedding = []
        self._device: torch.device = (device if device is not None else
                                      torch.device("cpu"))
        table_names = set()
        for embedding_config in tables:
            if embedding_config.name in table_names:
                raise ValueError(
                    f"Duplicate table name {embedding_config.name}")
            table_names.add(embedding_config.name)
            dtype = (torch.float32 if embedding_config.data_type
                     == DataType.FP32 else torch.float16)
            self.embedding_bags[embedding_config.name] = nn.EmbeddingBag(
                num_embeddings=embedding_config.num_embeddings,
                embedding_dim=embedding_config.embedding_dim,
                mode=pooling_type_to_str(embedding_config.pooling),
                device=self._device,
                include_last_offset=True,
                dtype=dtype,
            )
            if not embedding_config.feature_names:
                embedding_config.feature_names = [embedding_config.name]
            self._lengths_per_embedding.extend(
                len(embedding_config.feature_names) *
                [embedding_config.embedding_dim])
        self._embedding_names = [
            embedding for embeddings in get_embedding_names_by_table(tables)
            for embedding in embeddings
        ]
        self._feature_names = [table.feature_names for table in tables]
        self.reset_parameters()

    def forward(self, sparse_multihot_inputs, sparse_multihot_offsets):
        pooled_embeddings = []
        for i, embedding_bag in enumerate(self.embedding_bags.values()):
            # Columns [offset, offset + size) of the flattened multi-hot batch
            # hold the indices of the i-th sparse feature.
            size = CRITEO_SYNTH_MULTIHOT_SIZES[i]
            offset = CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM[i]
            print(f'{i = } {size = } {offset = }')
            emb_input = sparse_multihot_inputs[:, offset:offset + size].flatten()
            emb_offset = torch.cat([zero, sparse_multihot_offsets[:, i]]).flatten()
            res = embedding_bag(
                input=emb_input,
                offsets=emb_offset,
                per_sample_weights=None,
            ).float()
            pooled_embeddings.append(res)
        data = torch.cat(pooled_embeddings, dim=1)
        print(f'{data = }')
        return data

    def embedding_bag_configs(self):
        return self._embedding_bag_configs

    @property
    def device(self):
        return self._device

    def reset_parameters(self):
        if (isinstance(self.device, torch.device)
                and self.device.type == "meta") or (isinstance(
                    self.device, str) and self.device == "meta"):
            return
        # Initialize embedding bags weights with init_fn
        for table_config in self._embedding_bag_configs:
            assert table_config.init_fn is not None
            param = self.embedding_bags[f"{table_config.name}"].weight
            # pyre-ignore
            table_config.init_fn(param)
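
# Note on the offsets convention consumed by EmbeddingBagCollection.forward above:
# nn.EmbeddingBag with include_last_offset=True takes a flat index tensor plus an
# offsets tensor of batch_size + 1 entries, e.g. input=[3, 7, 7, 1] with
# offsets=[0, 3, 4] pools indices [3, 7, 7] into row 0 and [1] into row 1. So
# sparse_multihot_offsets[:, i] must contain, for every sample, the cumulative end
# position of feature i's indices, and a leading 0 is concatenated in forward().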


def choose(n: int, k: int) -> int:
    """
    Simple implementation of math.comb for Python 3.7 compatibility.
    """
    if 0 <= k <= n:
        ntok = 1
        ktok = 1
        for t in range(1, min(k, n - k) + 1):
            ntok *= n
            ktok *= t
            n -= 1
        return ntok // ktok
    else:
        return 0
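
# For the 26 Criteo sparse features, choose(26, 2) == 325, which matches
# over_in_features = embedding_dim + choose(num_sparse_features, 2) +
# num_sparse_features in DLRM.__init__ below: the dense output, the
# sparse-sparse dot products, and the dense-sparse dot products.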


class SparseArch(nn.Module):
    def __init__(self, embedding_bag_collection) -> None:
        super().__init__()
        self.embedding_bag_collection = embedding_bag_collection
        assert (self.embedding_bag_collection.embedding_bag_configs()
                ), "Embedding bag collection cannot be empty!"
        self.D = self.embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        self._sparse_feature_names = [
            name for conf in embedding_bag_collection.embedding_bag_configs()
            for name in conf.feature_names
        ]
        self.F = len(self._sparse_feature_names)

    def forward(
        self,
        sparse_multihot_inputs,
        sparse_multihot_offsets,
    ):
        sparse_values = self.embedding_bag_collection(sparse_multihot_inputs,
                                                      sparse_multihot_offsets)
        print(f'{sparse_values = }')
        return sparse_values.reshape(-1, self.F, self.D)

    @property
    def sparse_feature_names(self) -> List[str]:
        return self._sparse_feature_names


class DLRM(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
        dense_device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        assert (len(embedding_bag_collection.embedding_bag_configs())
                > 0), "At least one embedding bag is required"
        for i in range(1,
                       len(embedding_bag_collection.embedding_bag_configs())):
            conf_prev = embedding_bag_collection.embedding_bag_configs()[i - 1]
            conf = embedding_bag_collection.embedding_bag_configs()[i]
            assert (conf_prev.embedding_dim == conf.embedding_dim
                    ), "All EmbeddingBagConfigs must have the same dimension"
        embedding_dim: int = embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        if dense_arch_layer_sizes[-1] != embedding_dim:
            raise ValueError(
                f"embedding_bag_collection dimension ({embedding_dim}) and final dense "
                f"arch layer size ({dense_arch_layer_sizes[-1]}) must match.")
        self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
        num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
        self.dense_arch = DenseArch(
            in_features=dense_in_features,
            layer_sizes=dense_arch_layer_sizes,
            device=dense_device,
        )
        self.inter_arch = InteractionArch(
            num_sparse_features=num_sparse_features, )
        over_in_features: int = (embedding_dim +
                                 choose(num_sparse_features, 2) +
                                 num_sparse_features)
        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
            device=dense_device,
        )

    def forward(
        self,
        dense_features: torch.Tensor,
        sparse_multihot_inputs: torch.Tensor,
        sparse_multihot_offsets: torch.Tensor,
    ) -> torch.Tensor:
        print(f'{dense_features = }')
        embedded_dense = self.dense_arch(dense_features)
        print(f'{embedded_dense = }')
        print(f'{sparse_multihot_inputs = }')
        print(f'{sparse_multihot_offsets = }')
        embedded_sparse = self.sparse_arch(sparse_multihot_inputs,
                                           sparse_multihot_offsets)
        print(f'{embedded_sparse = }')
        concatenated_dense = self.inter_arch(dense_features=embedded_dense,
                                             sparse_features=embedded_sparse)
        print(f'{concatenated_dense = }')
        logits = self.over_arch(concatenated_dense)
        print(f'{logits = }')
        return logits


class DLRM_DCN(DLRM):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
        dcn_num_layers: int,
        dcn_low_rank_dim: int,
        dense_device: Optional[torch.device] = None,
    ) -> None:
        # initialize DLRM
        # sparse arch and dense arch are initialized via DLRM
        super().__init__(
            embedding_bag_collection,
            dense_in_features,
            dense_arch_layer_sizes,
            over_arch_layer_sizes,
            dense_device,
        )
        embedding_dim: int = embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
        # Fix interaction and over arch for DLRM_DCN
        crossnet = LowRankCrossNet(
            in_features=(num_sparse_features + 1) * embedding_dim,
            num_layers=dcn_num_layers,
            low_rank=dcn_low_rank_dim,
        )
        self.inter_arch = InteractionDCNArch(
            num_sparse_features=num_sparse_features,
            crossnet=crossnet,
        )
        over_in_features: int = (num_sparse_features + 1) * embedding_dim
        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
            device=dense_device,
        )


class DLRMInfer(nn.Module):
    def __init__(
        self,
        dlrm_module: DLRM_DCN,
    ) -> None:
        super().__init__()
        self.model = dlrm_module

    def forward(self, dense_features: torch.Tensor,
                sparse_multihot_inputs: torch.Tensor,
                sparse_multihot_offsets: torch.Tensor) -> torch.Tensor:
        logits = self.model(dense_features, sparse_multihot_inputs,
                            sparse_multihot_offsets)
        logits = logits.squeeze(-1)
        return logits
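

if __name__ == "__main__":
    # Minimal usage sketch: builds a shrunken DLRM_DCN and runs one forward pass
    # on random data. It assumes torchrec's EmbeddingBagConfig dataclass with the
    # name / feature_names / num_embeddings / embedding_dim / init_fn fields and a
    # PyTorch version that accepts int32 indices in nn.EmbeddingBag; adjust if your
    # versions differ. Table sizes are tiny stand-ins for the Criteo tables above.
    import functools

    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    embedding_dim = 8
    batch_size = 2
    num_features = len(CRITEO_SYNTH_MULTIHOT_SIZES)
    tables = [
        EmbeddingBagConfig(
            name=f"t_cat_{i}",
            feature_names=[f"cat_{i}"],
            num_embeddings=100,  # stand-in for CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE[i]
            embedding_dim=embedding_dim,
            init_fn=functools.partial(nn.init.uniform_, a=-0.01, b=0.01),
        ) for i in range(num_features)
    ]
    model = DLRMInfer(
        DLRM_DCN(
            embedding_bag_collection=EmbeddingBagCollection(tables),
            dense_in_features=13,
            dense_arch_layer_sizes=[32, embedding_dim],
            over_arch_layer_sizes=[32, 1],
            dcn_num_layers=2,
            dcn_low_rank_dim=16,
        ))
    dense_features = torch.rand(batch_size, 13)
    # Flattened multi-hot indices: one row of sum(sizes) == 214 columns per sample;
    # int32 to match the dtype of the module-level `zero` tensor.
    sparse_inputs = torch.cat([
        torch.randint(0, 100, (batch_size, size), dtype=torch.int32)
        for size in CRITEO_SYNTH_MULTIHOT_SIZES
    ], dim=1)
    # Per-sample cumulative end offsets for every feature, as expected by forward().
    sparse_offsets = torch.stack([
        torch.arange(1, batch_size + 1, dtype=torch.int32) * size
        for size in CRITEO_SYNTH_MULTIHOT_SIZES
    ], dim=1)
    print(model(dense_features, sparse_inputs, sparse_offsets))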