-
-
Save attila-dusnoki-htec/cdd2e547630bf2ac785095d73ed6e120 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requires torch==2.1.1 to avoid a strange fbgemm kernel issue:
#   pip install fbgemm-gpu==0.5.0 torchrec==0.5.0 --index-url https://download.pytorch.org/whl/cpu
#   pip install wheel torchsnapshot iopath pyre-extensions
#
# Get the weights:
# 1) wget (slow, and the downloaded data was corrupted for me; this may be specific to my setup):
#      wget https://cloud.mlcommons.org/index.php/s/XzfSeLgW8FYfR3S/download -O weights.zip
#      unzip weights.zip
# 2) CM method (much faster):
#      pip install cmind
#      cm pull repo mlcommons@ck
#      cm run script --tags=get,ml-model,dlrm,_pytorch,_weight_sharded,_rclone -j
import torch
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.criteo import INT_FEATURE_COUNT, CAT_FEATURE_COUNT
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchsnapshot import Snapshot

from dlrm_model import (
    EmbeddingBagCollection,
    DLRM_DCN,
    DLRMInfer,
    CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE,
    CRITEO_SYNTH_MULTIHOT_SIZES,
)
class DLRMv2_Model:
    """Build a DLRM-DCN inference model on CPU and restore its weights.

    The constructor only records hyper-parameters; the model itself is built
    and its weights restored by :meth:`load_model`, which caches the result on
    ``self.model``.
    """

    def __init__(
            self,
            model_path="CHAGE_THIS_TO_THE_FOLDER_CONTAINING_THE_UNCOMPRESSED_WEIGHTS",
            num_embeddings_per_feature=CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE,
            embedding_dim=128,
            dcn_num_layers=3,
            dcn_low_rank_dim=512,
            dense_arch_layer_sizes=(512, 256, 128),
            over_arch_layer_sizes=(1024, 1024, 512, 256, 1)):
        """Store the model configuration.

        Args:
            model_path: Path handed to ``Snapshot(path=...)``. The default is
                a placeholder and has to be replaced by the caller.
            num_embeddings_per_feature: Number of embeddings for each table,
                indexed in the order of ``DEFAULT_CAT_NAMES``.
            embedding_dim: Embedding dimension used for every table.
            dcn_num_layers: Forwarded to ``DLRM_DCN`` as ``dcn_num_layers``.
            dcn_low_rank_dim: Forwarded to ``DLRM_DCN`` as
                ``dcn_low_rank_dim``.
            dense_arch_layer_sizes: Forwarded to ``DLRM_DCN`` as
                ``dense_arch_layer_sizes``.
            over_arch_layer_sizes: Forwarded to ``DLRM_DCN`` as
                ``over_arch_layer_sizes``.
        """
        self.model_path = model_path
        # Copy the sequences so later mutation of the caller's objects (or of
        # the shared defaults) cannot leak into this instance.
        self.num_embeddings_per_feature = list(num_embeddings_per_feature)
        self.embedding_dim = embedding_dim
        self.dcn_num_layers = dcn_num_layers
        self.dcn_low_rank_dim = dcn_low_rank_dim
        self.dense_arch_layer_sizes = list(dense_arch_layer_sizes)
        self.over_arch_layer_sizes = list(over_arch_layer_sizes)
        # cache model to avoid re-loading
        self.model = None
        self.device = torch.device("cpu")

    def load_model(self):
        """Build the model, restore its weights and return it in eval mode.

        The first call constructs the embedding tables, the ``DLRM_DCN``
        network and the ``DLRMInfer`` wrapper, restores the weights from the
        snapshot at ``self.model_path`` and stores the result on
        ``self.model``. Later calls return that cached instance without
        rebuilding or touching the snapshot again.

        Returns:
            The ``DLRMInfer`` instance with restored weights.
        """
        # Honour the cache announced in __init__: building the tables and
        # restoring the snapshot is the expensive part.
        if self.model is not None:
            return self.model

        print('Loading Model...')
        print('create embedding_bag_configs')
        # One table per categorical feature, named "t_<feature>".
        self.embedding_bag_configs = [
            EmbeddingBagConfig(
                name=f"t_{feature_name}",
                embedding_dim=self.embedding_dim,
                num_embeddings=self.num_embeddings_per_feature[feature_idx],
                feature_names=[feature_name])
            for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
        ]
        print('create embedding_bag_collection')
        # create model
        self.embedding_bag_collection = EmbeddingBagCollection(
            tables=self.embedding_bag_configs, device=self.device)
        print('create DLRM_DCN')
        torchrec_dlrm_config = DLRM_DCN(
            embedding_bag_collection=self.embedding_bag_collection,
            dense_in_features=len(DEFAULT_INT_NAMES),
            dense_arch_layer_sizes=self.dense_arch_layer_sizes,
            over_arch_layer_sizes=self.over_arch_layer_sizes,
            dcn_num_layers=self.dcn_num_layers,
            dcn_low_rank_dim=self.dcn_low_rank_dim,
            dense_device=self.device)
        print('create DLRMInfer')
        model = DLRMInfer(torchrec_dlrm_config)
        print('create Snapshot')
        # load weights
        snapshot = Snapshot(path=self.model_path)
        print('Snapshot restore')
        snapshot.restore(app_state={"model": model})
        print('model eval')
        model.eval()
        # Only cache once everything above succeeded, so a failed restore
        # never leaves a half-initialised model behind.
        self.model = model
        return model
# Build the wrapper with its default hyper-parameters and load the weights.
# NOTE: the default model_path is a placeholder; point it at the folder with
# the uncompressed weights before running this script.
dlrm = DLRMv2_Model()
model = dlrm.load_model()
# Original data
# 0 40 42 2 54 3 0 0 2 16 0 1 4448 4 1acfe1ee 1b2ff61f 2e8b2631 6faef306 c6fc10d3 6fcd6dcb 16e08b25 670da99c 2e4e821f 5fd89f4d b21eb4c2 2974d88b bf78d0d4 52e56658 484a5e08 330c9d3e 1f7fc70b 5cc1303c 9512c20b 81ae47fc 405a6616 b9196e4d 9496de3d 1652193e 30436bfc b757e957
#
# NOTE(review): numerically the 13 values below equal ln(x + 3) of the 13
# integer columns that follow the leading label in the row above (e.g.
# ln(40 + 3) = 3.7612, ln(4448 + 3) = 8.4009) -- confirm against the
# preprocessing actually used.
# (batch_size, int_features(13))
dense_features = torch.tensor(
    [[
        3.7612002, 3.8066626, 1.609438, 4.0430512, 1.7917595, 1.0986123,
        1.0986123, 1.609438, 2.944439, 1.0986123, 1.3862944, 8.400884,
        1.9459102,
    ]],
    dtype=torch.float32)
# One list of integer ids per categorical feature ('0' .. '25'), in feature
# order. The lists have different lengths (multi-hot) and every list in this
# sample starts with the id 2.
# NOTE(review): presumably these are row indices into the matching embedding
# table -- confirm against the dataset preprocessing.
sparse_values = [
    [2, 7644169, 22630677],  # '0'
    [2, 22805],  # '1'
    [2],  # '2'
    [2, 1737],  # '3'
    [2, 1500, 6032, 17547, 14405, 8897],  # '4'
    [2],  # '5'
    [2],  # '6'
    [2],  # '7'
    [2],  # '8'
    [2, 15365678, 31486386, 13326029, 1156856, 23782249, 29001523],  # '9'
    [2, 3009115, 1136654],  # '10'
    [2, 121241, 345409, 288675, 327651, 352391, 19142, 374959],  # '11'
    [2],  # '12'
    [2, 2206, 1807, 1475, 1110, 1291],  # '13'
    [2, 9049, 7584, 8766, 10106, 3914, 4024, 9722, 5154],  # '14'
    [2, 123, 102, 110, 86],  # '15'
    [2],  # '16'
    [2],  # '17'
    [2],  # '18'
    [
        2, 30858202, 16955799, 171930, 22262279, 20261625, 10154107, 5022018,
        33567801, 14694455, 36048658, 15076233,
    ],  # '19'
    [
        2, 17928980, 22605151, 6905135, 19334652, 5951658, 9051258, 9413069,
        2312244, 19147547, 36250416, 33306570, 1294226, 16081508, 14439041,
        20343591, 38786290, 26010896, 24511663, 19431998, 12644296, 10975309,
        7982202, 19355498, 39321103, 20827904, 35320618, 36648112, 9137761,
        15008777, 19562345, 31356402, 8574314, 12178979, 5719685, 15138997,
        34998375, 33509048, 3540628, 17321726, 36476983, 21411478, 10111681,
        19850887, 15927126, 38349850, 13469844, 23182271, 10511383, 19672455,
        31395166, 38749537, 10739228, 18699110, 39645126, 7400979, 15867416,
        801119, 9339409, 29459918, 17056100, 28390158, 22394586, 10644990,
        6047353, 36280108, 4856527, 27099832, 3690777, 13138508, 261940,
        32323393, 30030481, 30383245, 25393665, 287495, 6319546, 15336616,
        32339139, 4743631, 5278525, 26277952, 26932186, 28312783, 28531202,
        3826605, 1943560, 16594040, 27573439, 9349610, 33541669, 4028561,
        13851067, 22737828, 4091831, 13305958, 12560732, 19916223, 17279725,
        3721895,
    ],  # '20'
    [
        2, 36801059, 832462, 19124375, 29812036, 37410895, 36453092, 1697722,
        38024156, 28780396, 29574226, 32240385, 24528012, 2772530, 2438735,
        33110534, 8662449, 10592422, 20163922, 33941757, 25984053, 769040,
        31897814, 34539405, 20356004, 32739750, 19762904,
    ],  # '21'
    [
        2, 124015, 159567, 82518, 250181, 185102, 130275, 246010, 171478,
        7379,
    ],  # '22'
    [2, 3267, 10861],  # '23'
    [2],  # '24'
    [2],  # '25'
]
# Concatenate the per-feature id lists into one flat sequence, keeping the
# feature order.
flattened_sparse_values = [value for array in sparse_values for value in array]

# sparse_offsets below is derived from CRITEO_SYNTH_MULTIHOT_SIZES, not from
# the data itself. Fail fast if the sample row does not have exactly that
# many ids per feature; otherwise the model would silently be fed ids
# attributed to the wrong feature.
_ids_per_feature = [len(array) for array in sparse_values]
if _ids_per_feature != list(CRITEO_SYNTH_MULTIHOT_SIZES):
    raise ValueError(
        "sparse_values does not match CRITEO_SYNTH_MULTIHOT_SIZES: "
        f"got {_ids_per_feature}, expected {list(CRITEO_SYNTH_MULTIHOT_SIZES)}")

# (batch_size, sum_of_sizes)
sparse_inputs = torch.tensor([flattened_sparse_values], dtype=torch.int32)
# (batch_size, cat_features(26))
# NOTE(review): each row is the per-feature size scaled by (row index + 1);
# with the single-sample batch used here that is just the sizes themselves.
# Confirm the intended layout before using a batch larger than one.
sparse_offsets = torch.tensor(
    [[size * (b + 1) for size in CRITEO_SYNTH_MULTIHOT_SIZES]
     for b in range(sparse_inputs.shape[0])],
    dtype=torch.int32)
labels = torch.tensor([0], dtype=torch.int32)

print(f'{dense_features.shape =}, {sparse_inputs.shape=}, {sparse_offsets.shape=}')
# Run the sample once so the output can be compared with the expected label.
result = model(dense_features, sparse_inputs, sparse_offsets)
print(f"actual {result = } expected {labels=}")

print("Export model")
# The same sample tensors are passed as the example inputs for the export.
torch.onnx.export(
    model, (dense_features, sparse_inputs, sparse_offsets),
    "dlrmv2.onnx",
    input_names=["dense_features", "sparse_inputs", "sparse_offsets"],
    output_names=["logits"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment