@rllin
Created September 22, 2022 01:40
import os
from typing import List, Optional
import torch
from torch import distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.models.dlrm import DLRM, DLRMTrain
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from tqdm import tqdm
from torchsnapshot import Snapshot
import torchrec.quant as trec_quant
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
import torchsnapshot
from torchsnapshot.io_preparer import TensorIOPreparer
from torchsnapshot.io_preparer import ShardedTensorIOPreparer
from torchsnapshot.manifest import Shard, ShardedTensorEntry
import tempfile
import functools
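# spawn `nprocs` worker processes that share a temporary directory for the
# file-based process-group rendezvous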
def start_multi(_per_rank, nprocs: int = 2):
    with tempfile.TemporaryDirectory() as tmpdir:
        torch.multiprocessing.start_processes(
            functools.partial(_per_rank, tmpdir=tmpdir),
            (nprocs,),
            nprocs=nprocs,
            start_method="fork",
        )
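# decorator that runs the wrapped function once per rank inside a gloo process group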
def distributed(nprocs: int = 2):
    def wrapper(func):
        def _inner(*args, **kwargs):
            def _setup_ddp(rank: int, world_size: int, tmpdir: str) -> None:
                """
                Setup DDP worker.
                """
                init_file = f"file://{os.path.join(tmpdir, 'init_file')}"
                dist.init_process_group(
                    backend="gloo",
                    rank=rank,
                    world_size=world_size,
                    init_method=init_file,
                )
                func(*args, **kwargs)
                dist.destroy_process_group()

            start_multi(_setup_ddp, nprocs=nprocs)

        return _inner

    return wrapper
@distributed(4)
def hello():
    print(f"{dist.get_rank()}, hello")
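# Build a DMP-sharded EmbeddingBagCollection, snapshot it twice (as-is and int8-quantized
# via a custom tensor prepare hook), restore both into fresh models, then compare lookups
# and on-disk size.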
@distributed(1)  # record
def train(
    # num_embeddings: int = 10,
    num_embeddings: int = 1024,
    embedding_dim: int = 128,
    # embedding_dim: int = 8,
) -> None:
    device = torch.device('cpu')
    table_names = ["feature1", "feature2"]
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            feature_names=[feature_name],
            # data_type=DataType.FP16,
        )
        for feature_name in table_names
    ]
    original_ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    original_ebc = DistributedModelParallel(module=original_ebc, device=device)

    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    features = KeyedJaggedTensor(
        keys=["feature1", "feature2"],
        values=torch.arange(num_embeddings),
        lengths=torch.ones(num_embeddings).int(),
    )
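    # baseline lookups from the live model, used below to validate both restores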
    original_lookups = original_ebc(features).values()
    torch.testing.assert_close(original_ebc(features).values(), original_lookups)
    print('before save lookup', original_ebc(features).values()[0][:10])
    '''
    # this should be ~3.5 mb vs 14mb for unquantized
    # to generate unquantized, remove the monkeypatching on ShardedTensorIOPreparer
    # print(original_ebc.state_dict())
    weights = torch.cat(
        [
            torch.flatten(_tensor.local_tensor()) for _tensor in original_ebc.state_dict().values()
        ]
    )
    # weights = original_ebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor()
    weight_min, weight_max = torch.min(weights), torch.max(weights)
    scale = (weight_max - weight_min) / (63 - -64)
    zero_point = -64 - (weight_min / scale)
    print(scale, zero_point)
    '''
    # function that looks at the model path and decides whether the module should be quantized;
    # in this simple example, everything has "embedding" in the path, so it all gets quantized
    def to_quantize(path: str) -> bool:
        """Returns `True` if `path` is a module that should be quantized."""
        return "embedding" in path
    # This part would be shoved into torchsnapshot as a utility
    def make_custom_tensor_prepare_func(to_quantize):
        quantized_dtype = torch.qint8

        def custom_tensor_prepare_func(path: str, tensor: torch.Tensor, tracing: bool = False):
            # in tracing mode, return a MetaTensor with correct dtype/size but do not allocate the memory;
            # otherwise, perform the op
            if to_quantize(path):
                if tracing:
                    return torch.tensor(tensor, dtype=quantized_dtype, device='meta')
                else:
                    observer = torch.quantization.observer.MinMaxObserver(dtype=quantized_dtype)
                    observer(tensor)
                    scale, zero_point = observer.calculate_qparams()
                    print(scale, zero_point)
                    return torch.quantize_per_tensor(tensor, scale, zero_point, quantized_dtype)
            else:
                return tensor

        return custom_tensor_prepare_func
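    # remove stale snapshot directories, then take one unmodified snapshot and one that
    # routes every tensor through the quantizing prepare hook above (note the
    # underscore-prefixed, non-public `_custom_tensor_prepare_func` argument)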
    import shutil

    if os.path.exists('./base'):
        shutil.rmtree('./base')
    if os.path.exists('./quant'):
        shutil.rmtree('./quant')

    base_snapshot = Snapshot.take(path="./base", app_state={"model": original_ebc})
    quant_snapshot = Snapshot.take(
        path="./quant",
        app_state={"model": original_ebc},
        _custom_tensor_prepare_func=make_custom_tensor_prepare_func(to_quantize),
    )
    # Snapshot.take(path="./quant", app_state={"model": original_ebc}, quantize=False)
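    # fresh restore targets, built on the meta device and sharded the same way as original_ebc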
    ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    ebc = DistributedModelParallel(module=ebc, device=device)
    qebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    qebc = DistributedModelParallel(module=qebc, device=device)
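    # the freshly materialized ebc has its own (re-initialized) weights, so this lookup is
    # not expected to match the baseline until the snapshot is restored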
    try:
        torch.testing.assert_close(ebc(features).values(), original_lookups)
    except AssertionError:
        pass

    base_snapshot.restore(app_state={"model": ebc})
    print(ebc.module.state_dict()['embedding_bags.t_feature1.weight'].dtype)
    quant_snapshot.restore(app_state={"model": qebc})
    print(qebc.module.state_dict()['embedding_bags.t_feature1.weight'].dtype)
    print('old weight', original_ebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor())
    print('new weight', qebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor())
    print(qebc(features).values())
    print(original_lookups)
    torch.testing.assert_close(original_ebc(features).values(), original_lookups)
    torch.testing.assert_close(qebc(features).values(), original_lookups, rtol=1e-3, atol=1e-3)
    print("*******************")
    print("*******************")
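    # compare the on-disk footprint of the two snapshots; the quantized one should be
    # noticeably smaller (the note above estimated ~3.5 MB vs ~14 MB unquantized)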
    from pathlib import Path

    base_size = sum(f.stat().st_size for f in Path('./base').glob('**/*') if f.is_file())
    quant_size = sum(f.stat().st_size for f in Path('./quant').glob('**/*') if f.is_file())
    print(base_size)
    print(quant_size)
    print(quant_size / base_size)


if __name__ == "__main__":
    train()
@yifuwang

On line 139, I think you want to move the tensor to cpu before quantizing.
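E.g., in the non-tracing branch of custom_tensor_prepare_func, something like the following (a sketch, assuming the shard may live on an accelerator):

    tensor = tensor.detach().cpu()  # move to CPU before observing/quantizing
    observer = torch.quantization.observer.MinMaxObserver(dtype=quantized_dtype)
    observer(tensor)
    scale, zero_point = observer.calculate_qparams()
    return torch.quantize_per_tensor(tensor, scale, zero_point, quantized_dtype)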
