Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leslie-fang-intel/ba5f37035bf63fb787a4f831fa09f911 to your computer and use it in GitHub Desktop.
Save leslie-fang-intel/ba5f37035bf63fb787a4f831fa09f911 to your computer and use it in GitHub Desktop.
import torch
import torchvision
import torch._dynamo as torchdynamo
import copy
from torch.ao.quantization._pt2e.quantizer import (
QNNPackQuantizer,
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
prepare_qat_pt2e_quantizer,
)
def pytorch_pt2e_qat(
    model_fp,
    data,
    graph_name="resnet50",
    svg_path="./rn50_qat_pt2e_prepare.svg",
):
    """Run the PT2E quantization-aware-training flow on *model_fp*.

    The model is exported with TorchDynamo (ATen-level graph). It is then
    prepared for QAT with a QNNPack quantizer using a symmetric, per-channel
    QAT config, and finally converted with ``convert_pt2e``. The prepared and
    converted graphs are printed. The prepared graph is optionally rendered
    to SVG.

    Args:
        model_fp: Floating-point module to quantize.
        data: Example input; it is wrapped in a 1-tuple and used both for
            export and for the forward passes below.
        graph_name: Title passed to ``FxGraphDrawer`` for the prepared graph.
        svg_path: Output path for the SVG rendering of the prepared graph.
            Pass ``None`` to skip drawing. Drawing needs the optional
            pydot/graphviz dependencies used by ``FxGraphDrawer``.

    Returns:
        The model returned by ``convert_pt2e``.
    """
    example_inputs = (data,)
    # Export a deep copy of the inputs so tracing cannot alias the tensors
    # that are reused for the forward passes below.
    m, _guards = torchdynamo.export(
        model_fp,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
    )
    # Forward pass on the exported float graph (result intentionally unused).
    m(*example_inputs)

    import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

    quantizer = QNNPackQuantizer()
    quantizer.set_global(
        qq.get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
    )
    # Insert observers / fake-quant modules for QAT.
    m = prepare_qat_pt2e_quantizer(m, quantizer)
    print(f"prepared model is: {m}", flush=True)

    if svg_path is not None:
        from torch.fx.passes.graph_drawer import FxGraphDrawer

        FxGraphDrawer(m, graph_name).get_dot_graph().write_svg(svg_path)

    # Keep this forward pass: it runs the inserted observers before
    # conversion. The output itself is not needed.
    m(*example_inputs)

    m = convert_pt2e(m)
    print(f"converted model is: {m}", flush=True)
    # NOTE: an earlier draft sketched tracing/freezing the converted model
    # with torch.jit (eval + no_grad + trace + freeze + graph_for); re-add
    # that step here if a TorchScript artifact is required.
    return m
if __name__ == "__main__":
    data = torch.randn(1, 3, 224, 224)
    # ``pretrained=True`` is deprecated since torchvision 0.13. IMAGENET1K_V1
    # is the checkpoint that ``pretrained=True`` resolves to, so the loaded
    # weights are unchanged.
    model_fp = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    )
    print("--------------PyTorch 2.0 QAT -----------")
    pytorch_pt2e_qat(model_fp, data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment