#!/usr/bin/env python3
import torch
import torch.nn as nn

from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.modules.embedding_configs import pooling_type_to_str, DataType
from torchrec.modules.embedding_modules import get_embedding_names_by_table
from torchrec.models.dlrm import DenseArch, InteractionArch, InteractionDCNArch, OverArch

from typing import List, Optional

# Table sizes (number of embedding rows) for the 26 categorical features of
# the MLPerf Criteo synthetic multi-hot dataset.
CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE = [
    40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, 3067956,
    405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 40000000, 40000000,
    590152, 12973, 108, 36
]

# Number of multi-hot ids per categorical feature, and the running column
# offset of each feature within the flattened multi-hot input tensor.
CRITEO_SYNTH_MULTIHOT_SIZES = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1]
CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM = [sum(CRITEO_SYNTH_MULTIHOT_SIZES[:i]) for i in range(len(CRITEO_SYNTH_MULTIHOT_SIZES))]
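
# Illustration (derived from the constants above, not in the original gist):
# feature 0 occupies columns [0, 3) of the flattened multi-hot input,
# feature 1 columns [3, 5), feature 2 column [5, 6), and so on, i.e.
#   CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM[:4] == [0, 3, 5, 6]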

# Leading zero prepended to each per-feature offset column in forward().
zero = torch.zeros(1, dtype=torch.int32)


class EmbeddingBagCollection(nn.Module):
    """Plain nn.EmbeddingBag-based stand-in for torchrec's EmbeddingBagCollection."""

    def __init__(
        self,
        tables,
        device=None,
    ):
        super().__init__()
        torch._C._log_api_usage_once(
            f"torchrec.modules.{self.__class__.__name__}")
        self.embedding_bags = nn.ModuleDict()
        self._embedding_bag_configs = tables
        self._lengths_per_embedding = []
        self._device: torch.device = (device if device is not None else
                                      torch.device("cpu"))
        table_names = set()
        for embedding_config in tables:
            if embedding_config.name in table_names:
                raise ValueError(
                    f"Duplicate table name {embedding_config.name}")
            table_names.add(embedding_config.name)
            dtype = (torch.float32 if embedding_config.data_type
                     == DataType.FP32 else torch.float16)
            self.embedding_bags[embedding_config.name] = nn.EmbeddingBag(
                num_embeddings=embedding_config.num_embeddings,
                embedding_dim=embedding_config.embedding_dim,
                mode=pooling_type_to_str(embedding_config.pooling),
                device=self._device,
                include_last_offset=True,
                dtype=dtype,
            )
            if not embedding_config.feature_names:
                embedding_config.feature_names = [embedding_config.name]
            self._lengths_per_embedding.extend(
                len(embedding_config.feature_names) *
                [embedding_config.embedding_dim])
        self._embedding_names = [
            embedding for embeddings in get_embedding_names_by_table(tables)
            for embedding in embeddings
        ]
        self._feature_names = [table.feature_names for table in tables]
        self.reset_parameters()

    def forward(self, sparse_multihot_inputs, sparse_multihot_offsets):
        pooled_embeddings = []
        for i, embedding_bag in enumerate(self.embedding_bags.values()):
            # Slice out this feature's fixed-size block of multi-hot ids and
            # flatten it into the 1-D input format nn.EmbeddingBag expects.
            size = CRITEO_SYNTH_MULTIHOT_SIZES[i]
            offset = CRITEO_SYNTH_MULTIHOT_SIZES_PREFIX_SUM[i]
            print(f'{i = } {size = } {offset = }')
            emb_input = sparse_multihot_inputs[:, offset:offset + size].flatten()
            # Prepend 0 so the offsets are valid for include_last_offset=True
            # (length batch_size + 1, starting at 0).
            emb_offset = torch.cat([zero, sparse_multihot_offsets[:, i]]).flatten()
            res = embedding_bag(
                input=emb_input,
                offsets=emb_offset,
                per_sample_weights=None,
            ).float()
            pooled_embeddings.append(res)
        data = torch.cat(pooled_embeddings, dim=1)
        print(f'{data = }')
        return data

    def embedding_bag_configs(self):
        return self._embedding_bag_configs

    @property
    def device(self):
        return self._device

    def reset_parameters(self):
        # Meta-device modules have no storage to initialize.
        if (isinstance(self.device, torch.device)
                and self.device.type == "meta") or (isinstance(
                    self.device, str) and self.device == "meta"):
            return
        # Initialize embedding bag weights with each table's init_fn.
        for table_config in self._embedding_bag_configs:
            assert table_config.init_fn is not None
            param = self.embedding_bags[f"{table_config.name}"].weight
            # pyre-ignore
            table_config.init_fn(param)


def choose(n: int, k: int) -> int:
    """
    Simple implementation of math.comb for Python 3.7 compatibility.
    """
    if 0 <= k <= n:
        ntok = 1
        ktok = 1
        for t in range(1, min(k, n - k) + 1):
            ntok *= n
            ktok *= t
            n -= 1
        return ntok // ktok
    else:
        return 0
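
# For example (assuming 26 sparse features, as in the Criteo constants above),
# choose(26, 2) == 325 pairwise interaction terms feed DLRM's over-arch below.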


class SparseArch(nn.Module):
    def __init__(self, embedding_bag_collection) -> None:
        super().__init__()
        self.embedding_bag_collection = embedding_bag_collection
        # Call the accessor; asserting on the bound method itself would
        # always pass.
        assert (self.embedding_bag_collection.embedding_bag_configs()
                ), "Embedding bag collection cannot be empty!"
        self.D = self.embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        self._sparse_feature_names = [
            name for conf in embedding_bag_collection.embedding_bag_configs()
            for name in conf.feature_names
        ]
        self.F = len(self._sparse_feature_names)

    def forward(
        self,
        sparse_multihot_inputs,
        sparse_multihot_offsets,
    ):
        sparse_values = self.embedding_bag_collection(sparse_multihot_inputs, sparse_multihot_offsets)
        print(f'{sparse_values = }')
        # One pooled embedding per sparse feature per sample: (B, F, D).
        return sparse_values.reshape(-1, self.F, self.D)

    @property
    def sparse_feature_names(self) -> List[str]:
        return self._sparse_feature_names


class DLRM(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
        dense_device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        assert (len(embedding_bag_collection.embedding_bag_configs())
                > 0), "At least one embedding bag is required"
        for i in range(1,
                       len(embedding_bag_collection.embedding_bag_configs())):
            conf_prev = embedding_bag_collection.embedding_bag_configs()[i - 1]
            conf = embedding_bag_collection.embedding_bag_configs()[i]
            assert (conf_prev.embedding_dim == conf.embedding_dim
                    ), "All EmbeddingBagConfigs must have the same dimension"
        embedding_dim: int = embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        if dense_arch_layer_sizes[-1] != embedding_dim:
            raise ValueError(
                f"embedding_bag_collection dimension ({embedding_dim}) and final dense "
                f"arch layer size ({dense_arch_layer_sizes[-1]}) must match.")
        self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
        num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
        self.dense_arch = DenseArch(
            in_features=dense_in_features,
            layer_sizes=dense_arch_layer_sizes,
            device=dense_device,
        )
        self.inter_arch = InteractionArch(
            num_sparse_features=num_sparse_features,
        )
        # Dot-product interactions: D dense features + C(F, 2) sparse-sparse
        # pairs + F dense-sparse pairs.
        over_in_features: int = (embedding_dim +
                                 choose(num_sparse_features, 2) +
                                 num_sparse_features)
        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
            device=dense_device,
        )

    def forward(
        self,
        dense_features: torch.Tensor,
        sparse_multihot_inputs: torch.Tensor,
        sparse_multihot_offsets: torch.Tensor,
    ) -> torch.Tensor:
        print(f'{dense_features = }')
        embedded_dense = self.dense_arch(dense_features)
        print(f'{embedded_dense = }')
        print(f'{sparse_multihot_inputs = }')
        print(f'{sparse_multihot_offsets = }')
        embedded_sparse = self.sparse_arch(sparse_multihot_inputs, sparse_multihot_offsets)
        print(f'{embedded_sparse = }')
        concatenated_dense = self.inter_arch(dense_features=embedded_dense,
                                             sparse_features=embedded_sparse)
        print(f'{concatenated_dense = }')
        logits = self.over_arch(concatenated_dense)
        print(f'{logits = }')
        return logits


class DLRM_DCN(DLRM):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
        dcn_num_layers: int,
        dcn_low_rank_dim: int,
        dense_device: Optional[torch.device] = None,
    ) -> None:
        # Initialize DLRM; the sparse arch and dense arch are set up there.
        super().__init__(
            embedding_bag_collection,
            dense_in_features,
            dense_arch_layer_sizes,
            over_arch_layer_sizes,
            dense_device,
        )
        embedding_dim: int = embedding_bag_collection.embedding_bag_configs(
        )[0].embedding_dim
        num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
        # Replace the interaction and over arch with DCN-specific ones.
        crossnet = LowRankCrossNet(
            in_features=(num_sparse_features + 1) * embedding_dim,
            num_layers=dcn_num_layers,
            low_rank=dcn_low_rank_dim,
        )
        self.inter_arch = InteractionDCNArch(
            num_sparse_features=num_sparse_features,
            crossnet=crossnet,
        )
        over_in_features: int = (num_sparse_features + 1) * embedding_dim
        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
            device=dense_device,
        )
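
# Sizing sketch (derived from the code above): the DCN interaction arch
# concatenates the dense embedding with the F sparse embeddings, so both the
# cross network and the over-arch see (num_sparse_features + 1) * embedding_dim
# inputs, e.g. (26 + 1) * 128 = 3456 with the Criteo constants.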


class DLRMInfer(nn.Module):
    """Inference wrapper that squeezes the trailing logit dimension."""

    def __init__(
        self,
        dlrm_module: DLRM_DCN,
    ) -> None:
        super().__init__()
        self.model = dlrm_module

    def forward(self, dense_features: torch.Tensor,
                sparse_multihot_inputs: torch.Tensor,
                sparse_multihot_offsets: torch.Tensor) -> torch.Tensor:
        logits = self.model(dense_features, sparse_multihot_inputs, sparse_multihot_offsets)
        logits = logits.squeeze(-1)
        return logits
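

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes torchrec's
# EmbeddingBagConfig/PoolingType API, shrinks every table to at most 1000 rows
# so it runs on CPU, and borrows the MLPerf DLRM-DCNv2 layer sizes; all of
# these values are illustrative assumptions, not the gist author's settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType

    EMBEDDING_DIM = 128  # must equal dense_arch_layer_sizes[-1]
    n_embed = [min(n, 1000) for n in CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE]

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_cat_{i}",
            embedding_dim=EMBEDDING_DIM,
            num_embeddings=n,
            feature_names=[f"cat_{i}"],
            pooling=PoolingType.SUM,
            # reset_parameters() above asserts init_fn is not None.
            init_fn=lambda w: nn.init.uniform_(w, -0.05, 0.05),
        )
        for i, n in enumerate(n_embed)
    ]

    model = DLRMInfer(
        DLRM_DCN(
            embedding_bag_collection=EmbeddingBagCollection(eb_configs),
            dense_in_features=13,
            dense_arch_layer_sizes=[512, 256, EMBEDDING_DIM],
            over_arch_layer_sizes=[1024, 1024, 512, 256, 1],
            dcn_num_layers=3,
            dcn_low_rank_dim=512,
        )
    ).eval()

    # Fake batch: 13 dense floats per sample, the flattened multi-hot ids
    # (sum(CRITEO_SYNTH_MULTIHOT_SIZES) == 214 columns), and the per-feature
    # cumulative offsets expected by EmbeddingBagCollection.forward().
    B = 2
    dense = torch.rand(B, 13)
    inputs = torch.cat([
        torch.randint(0, n, (B, s), dtype=torch.int32)
        for n, s in zip(n_embed, CRITEO_SYNTH_MULTIHOT_SIZES)
    ], dim=1)
    offsets = torch.stack([
        torch.arange(1, B + 1, dtype=torch.int32) * s
        for s in CRITEO_SYNTH_MULTIHOT_SIZES
    ], dim=1)

    with torch.no_grad():
        print(model(dense, inputs, offsets).shape)  # -> torch.Size([2])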