# Execution Command:
import torch
import torchvision.models as models
import torch._dynamo as torchdynamo
import copy
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random
import numpy as np
from torch._export import capture_pre_autograd_graph, dynamic_dim
import time

# Fix the seeds for reproducibility
random.seed(2023)
torch.manual_seed(2023)
np.random.seed(2023)


def run_model(model_name):
    print("start int8 test of model: {}".format(model_name), flush=True)
    traced_bs = 50
    model = models.__dict__[model_name](pretrained=True).eval()
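    # channels_last (NHWC) is generally the faster memory format for
    # convolutions on x86 CPUs and is the layout the Inductor CPU backend prefers.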
    x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)
    example_inputs = (x,)
    start = time.time()
    with torch.no_grad():
        # Generate the FX Module
        exported_model = capture_pre_autograd_graph(
            model,
            example_inputs,
        )
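        # NOTE: capture_pre_autograd_graph is a prototype API from
        # torch._export; later PyTorch releases replace it with torch.export APIs.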
print("exported_model is: {}".format(exported_model), flush=True) | |
# Create X86InductorQuantizer | |
quantizer = X86InductorQuantizer() | |
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) | |
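        # set_global applies the default x86 Inductor quantization config to
        # every quantizable operator the quantizer recognizes in the model.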

        # PT2E quantization flow: prepare_pt2e inserts observers into the
        # exported graph according to the quantizer's annotations
        prepared_model = prepare_pt2e(exported_model, quantizer)
        print("prepared_model is: {}".format(prepared_model), flush=True)

        # Calibration: run data through the prepared model so the observers
        # can record activation statistics
        prepared_model(x)
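        # NOTE: a single random batch is only a stand-in here; real
        # calibration should use representative samples from the target dataset.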
        converted_model = convert_pt2e(prepared_model).eval()
        print("converted_model is: {}".format(converted_model), flush=True)
    end = time.time()
    print('Time consumption is: {} seconds'.format(end - start))
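
    # Optional next step (not part of the original gist; a minimal sketch of
    # the usual PT2E flow): lower the converted model to the Inductor CPU
    # backend with torch.compile and run it once to trigger compilation.
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        optimized_model(x)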


if __name__ == "__main__":
    model_list = ["resnet152"]
    for model in model_list:
        run_model(model)