# requires torch==2.1.1 to avoid a strange fbgemm kernel issue
# pip install fbgemm-gpu==0.5.0 torchrec==0.5.0 --index-url https://download.pytorch.org/whl/cpu
# pip install wheel torchsnapshot iopath pyre-extensions
# Get the weights:
# 1) wget: slow, and (at least for me) the downloaded data ended up corrupted
# wget https://cloud.mlcommons.org/index.php/s/XzfSeLgW8FYfR3S/download -O weights.zip
# unzip weights.zip
# 2) CM method, much faster
# pip install cmind
# cm pull repo mlcommons@ck
# cm run script --tags=get,ml-model,dlrm,_pytorch,_weight_sharded,_rclone -j
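# NOTE (assumption): dlrm_model.py is a local helper module that is not part of
# this gist. It is expected to provide the torchrec-style DLRM_DCN network, the
# EmbeddingBagCollection wrapper, the DLRMInfer inference wrapper, and the
# CRITEO_SYNTH_MULTIHOT_* constants imported below.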
import torch
from dlrm_model import EmbeddingBagCollection, DLRM_DCN, DLRMInfer, CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE, CRITEO_SYNTH_MULTIHOT_SIZES
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.criteo import INT_FEATURE_COUNT, CAT_FEATURE_COUNT
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchsnapshot import Snapshot
class DLRMv2_Model:
    def __init__(
            self,
            model_path="CHANGE_THIS_TO_THE_FOLDER_CONTAINING_THE_UNCOMPRESSED_WEIGHTS",
            num_embeddings_per_feature=CRITEO_SYNTH_MULTIHOT_N_EMBED_PER_FEATURE,
            embedding_dim=128,
            dcn_num_layers=3,
            dcn_low_rank_dim=512,
            dense_arch_layer_sizes=(512, 256, 128),
            over_arch_layer_sizes=(1024, 1024, 512, 256, 1)):
        self.model_path = model_path
        self.num_embeddings_per_feature = list(num_embeddings_per_feature)
        self.embedding_dim = embedding_dim
        self.dcn_num_layers = dcn_num_layers
        self.dcn_low_rank_dim = dcn_low_rank_dim
        self.dense_arch_layer_sizes = list(dense_arch_layer_sizes)
        self.over_arch_layer_sizes = list(over_arch_layer_sizes)
        # cache model to avoid re-loading
        self.model = None
        self.device = torch.device("cpu")

    def load_model(self):
        print('Loading Model...')
        print('create embedding_bag_configs')
        self.embedding_bag_configs = [
            EmbeddingBagConfig(
                name=f"t_{feature_name}",
                embedding_dim=self.embedding_dim,
                num_embeddings=self.num_embeddings_per_feature[feature_idx],
                feature_names=[feature_name])
            for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
        ]
        print('create embedding_bag_collection')
        # create model
        self.embedding_bag_collection = EmbeddingBagCollection(
            tables=self.embedding_bag_configs, device=self.device)
        print('create DLRM_DCN')
        torchrec_dlrm_config = DLRM_DCN(
            embedding_bag_collection=self.embedding_bag_collection,
            dense_in_features=len(DEFAULT_INT_NAMES),
            dense_arch_layer_sizes=self.dense_arch_layer_sizes,
            over_arch_layer_sizes=self.over_arch_layer_sizes,
            dcn_num_layers=self.dcn_num_layers,
            dcn_low_rank_dim=self.dcn_low_rank_dim,
            dense_device=self.device)
        print('create DLRMInfer')
        model = DLRMInfer(torchrec_dlrm_config)
        print('create Snapshot')
        # load weights
        snapshot = Snapshot(path=self.model_path)
        print('Snapshot restore')
        snapshot.restore(app_state={"model": model})
        print('model eval')
        model.eval()
        self.model = model
        return model
dlrm = DLRMv2_Model()
model = dlrm.load_model()
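# Quick sanity check on the restored model (assumption: DLRMInfer is a standard
# torch.nn.Module, so parameters() covers the embedding tables as well).
num_params = sum(p.numel() for p in model.parameters())
print(f"restored model has {num_params:,} parameters")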
# Original data: one raw Criteo row (label, 13 integer features, 26 categorical feature hashes)
# 0 40 42 2 54 3 0 0 2 16 0 1 4448 4 1acfe1ee 1b2ff61f 2e8b2631 6faef306 c6fc10d3 6fcd6dcb 16e08b25 670da99c 2e4e821f 5fd89f4d b21eb4c2 2974d88b bf78d0d4 52e56658 484a5e08 330c9d3e 1f7fc70b 5cc1303c 9512c20b 81ae47fc 405a6616 b9196e4d 9496de3d 1652193e 30436bfc b757e957
# (batch_size, int_features(13))
dense_features = torch.tensor([[
    3.7612002, 3.8066626, 1.609438, 4.0430512, 1.7917595, 1.0986123, 1.0986123,
    1.609438, 2.944439, 1.0986123, 1.3862944, 8.400884, 1.9459102
]], dtype=torch.float32)
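# Note: these 13 values appear to be the integer columns of the raw row above
# after a log(x + 3) transform, e.g. log(40 + 3) ~= 3.7612 for the first column.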
sparse_values = [
[2, 7644169, 22630677], # '0'
[2, 22805], # '1'
[2], # '2'
[2, 1737], # '3'
[2, 1500, 6032, 17547, 14405, 8897], # '4'
[2], # '5'
[2], # '6'
[2], # '7'
[2], # '8'
[2, 15365678, 31486386, 13326029, 1156856, 23782249, 29001523], # '9'
[2, 3009115, 1136654], # '10'
[2, 121241, 345409, 288675, 327651, 352391, 19142, 374959], # '11'
[2], # '12'
[2, 2206, 1807, 1475, 1110, 1291], # '13'
[2, 9049, 7584, 8766, 10106, 3914, 4024, 9722, 5154], # '14'
[2, 123, 102, 110, 86], # '15'
[2], # '16'
[2], # '17'
[2], # '18'
[
2, 30858202, 16955799, 171930, 22262279, 20261625, 10154107, 5022018,
33567801, 14694455, 36048658, 15076233
], # '19'
[
2, 17928980, 22605151, 6905135, 19334652, 5951658, 9051258, 9413069,
2312244, 19147547, 36250416, 33306570, 1294226, 16081508, 14439041,
20343591, 38786290, 26010896, 24511663, 19431998, 12644296, 10975309,
7982202, 19355498, 39321103, 20827904, 35320618, 36648112, 9137761,
15008777, 19562345, 31356402, 8574314, 12178979, 5719685, 15138997,
34998375, 33509048, 3540628, 17321726, 36476983, 21411478, 10111681,
19850887, 15927126, 38349850, 13469844, 23182271, 10511383, 19672455,
31395166, 38749537, 10739228, 18699110, 39645126, 7400979, 15867416,
801119, 9339409, 29459918, 17056100, 28390158, 22394586, 10644990,
6047353, 36280108, 4856527, 27099832, 3690777, 13138508, 261940,
32323393, 30030481, 30383245, 25393665, 287495, 6319546, 15336616,
32339139, 4743631, 5278525, 26277952, 26932186, 28312783, 28531202,
3826605, 1943560, 16594040, 27573439, 9349610, 33541669, 4028561,
13851067, 22737828, 4091831, 13305958, 12560732, 19916223, 17279725,
3721895
], # '20'
[
2, 36801059, 832462, 19124375, 29812036, 37410895, 36453092, 1697722,
38024156, 28780396, 29574226, 32240385, 24528012, 2772530, 2438735,
33110534, 8662449, 10592422, 20163922, 33941757, 25984053, 769040,
31897814, 34539405, 20356004, 32739750, 19762904
], # '21'
[2, 124015, 159567, 82518, 250181, 185102, 130275, 246010, 171478,
7379], # '22'
[2, 3267, 10861], # '23'
[2], # '24'
[2] # '25'
]
flattened_sparse_values = [value for array in sparse_values for value in array]
# (batch_size, sum_of_sizes)
sparse_inputs = torch.tensor([flattened_sparse_values], dtype=torch.int32)
# (batch_size, cat_features(26))
sparse_offsets = torch.tensor(
    [[size * (b + 1) for size in CRITEO_SYNTH_MULTIHOT_SIZES]
     for b in range(len(sparse_inputs))],
    dtype=torch.int32)
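# Sanity checks (assumption: CRITEO_SYNTH_MULTIHOT_SIZES lists the per-feature
# multi-hot lengths used to build sparse_values above).
assert sparse_inputs.shape[1] == sum(CRITEO_SYNTH_MULTIHOT_SIZES), \
    "flattened sparse values do not add up to the expected multi-hot sizes"
assert sparse_offsets.shape == (sparse_inputs.shape[0], CAT_FEATURE_COUNT)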
labels = torch.tensor([0], dtype=torch.int32)
print(f'{dense_features.shape=}, {sparse_inputs.shape=}, {sparse_offsets.shape=}')
result = model(dense_features, sparse_inputs, sparse_offsets)
print(f"actual {result=}, expected {labels=}")
print("Export model")
torch.onnx.export(
    model, (dense_features, sparse_inputs, sparse_offsets),
    "dlrmv2.onnx",
    input_names=["dense_features", "sparse_inputs", "sparse_offsets"],
    output_names=["logits"])
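# Optional check of the exported graph against the eager result (a minimal
# sketch, assuming onnxruntime is installed; not part of the original script).
import onnxruntime as ort

sess = ort.InferenceSession("dlrmv2.onnx", providers=["CPUExecutionProvider"])
onnx_logits = sess.run(["logits"], {
    "dense_features": dense_features.numpy(),
    "sparse_inputs": sparse_inputs.numpy(),
    "sparse_offsets": sparse_offsets.numpy(),
})[0]
print(f"onnx logits = {onnx_logits}, torch logits = {result.detach().numpy()}")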