Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created July 18, 2024 23:26
Show Gist options
  • Save justinchuby/0667c40bb44be94f316e0b5b2898e004 to your computer and use it in GitHub Desktop.
Save justinchuby/0667c40bb44be94f316e0b5b2898e004 to your computer and use it in GitHub Desktop.
Export quantized model to ONNX in PyTorch 2
import torch
class M(torch.nn.Module):
    """Minimal one-layer model used as the quantization/export example."""

    def __init__(self) -> None:
        """Create the single linear layer (5 input features, 10 output features)."""
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return its output."""
        return self.linear(x)
# Positional example inputs as a 1-tuple: one random tensor of shape (1, 5),
# matching the 5 input features of M.linear.
example_inputs = (torch.randn(1, 5),)
# Instantiate the model and switch it to evaluation mode before capturing it.
m = M().eval()
# Step 1. program capture
# The captured result is what the quantization step below prepares and converts.
# NOTE(review): capture_pre_autograd_graph is imported from the private
# torch._export module; its availability presumably depends on the installed
# PyTorch version -- confirm before upgrading.
from torch._export import capture_pre_autograd_graph
pt2e_torch_model = capture_pre_autograd_graph(m, example_inputs)
# Step 2. quantization
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
# Build an XNNPACK quantizer and apply the default symmetric quantization
# config (no arguments passed) as its global setting.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# Prepare the captured model for calibration with that quantizer; the name
# pt2e_torch_model is rebound to the prepared model.
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, quantizer)
# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*example_inputs)
# Convert the prepared model to a quantized model
# NOTE(review): fold_quantize=False is passed explicitly; presumably this keeps
# the quantize ops in the graph instead of folding them into constants --
# confirm against the convert_pt2e documentation.
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)
# Export the quantized model with the same example inputs used for capture.
program = torch.export.export(pt2e_torch_model, example_inputs)
# we get a model with aten ops
print(program)
# Convert to ONNX
# torch_onnx is a third-party package; patch_torch() is called before
# torch.onnx.export below.
# NOTE(review): presumably the patch is what lets torch.onnx.export accept the
# exported program and write a ".textproto" file -- confirm against the
# torch_onnx documentation.
import torch_onnx
torch_onnx.patch_torch(error_report=True)
# The third positional argument is the output file name. The returned
# onnx_program is not used further in this script.
onnx_program = torch.onnx.export(program, example_inputs, "quantized.textproto")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment