@leslie-fang-intel
Last active May 9, 2024 00:58
Toy Example code for Quantization in PyTorch 2.0 Export Tutorial
import copy
import itertools
import operator
from typing import Callable, Dict, List, Optional, Set, Any

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_bias_qspec,
    get_weight_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.ao.quantization._pt2e.quantizer.quantizer import (
    OperatorConfig,
    QuantizationConfig,
    QuantizationSpec,
    Quantizer,
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
import torchvision
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e_quantizer,
)


def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True


def _is_annotated(nodes: List[Node]):
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


class BackendQuantizer(Quantizer):

    def __init__(self):
        super().__init__()
        self.global_config: QuantizationConfig = None  # type: ignore[assignment]
        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}

    def set_global(self, quantization_config: QuantizationConfig):
        """set global QuantizationConfig used for the backend.
        QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py.
        """
        self.global_config = quantization_config
        return self

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """annotate nodes in the graph with observer or fake quant constructors
        to convey the desired way of quantization.
        """
        global_config = self.global_config
        self.annotate_symmetric_config(model, global_config)
        return model

    def annotate_symmetric_config(
        self, model: torch.fx.GraphModule, config: QuantizationConfig
    ) -> torch.fx.GraphModule:
        self._annotate_linear(model, config)
        self._annotate_conv2d(model, config)
        self._annotate_maxpool2d(model, config)
        return model

    def _annotate_conv2d(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
    ) -> None:
        conv_partitions = get_source_partitions(
            gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
        )
        conv_partitions = list(itertools.chain(*conv_partitions.values()))
        for conv_partition in conv_partitions:
            if len(conv_partition.output_nodes) > 1:
                raise ValueError("conv partition has more than one output node")
            conv_node = conv_partition.output_nodes[0]
            if (
                conv_node.op != "call_function"
                or conv_node.target != torch.ops.aten.convolution.default
            ):
                raise ValueError(f"{conv_node} is not an aten conv2d operator")
            # skip annotation if it is already annotated
            if _is_annotated([conv_node]):
                continue
            input_qspec_map = {}
            input_act = conv_node.args[0]
            assert isinstance(input_act, Node)
            input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
            weight = conv_node.args[1]
            assert isinstance(weight, Node)
            input_qspec_map[weight] = get_weight_qspec(quantization_config)
            bias = conv_node.args[2]
            if isinstance(bias, Node):
                input_qspec_map[bias] = get_bias_qspec(quantization_config)
            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=get_output_act_qspec(quantization_config),
                _annotated=True,
            )

    def _annotate_linear(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
    ) -> None:
        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
        )
        act_qspec = get_input_act_qspec(quantization_config)
        weight_qspec = get_weight_qspec(quantization_config)
        bias_qspec = get_bias_qspec(quantization_config)
        for module_or_fn_type, partitions in module_partitions.items():
            if module_or_fn_type == torch.nn.Linear:
                for p in partitions:
                    act_node = p.input_nodes[0]
                    output_node = p.output_nodes[0]
                    weight_node = None
                    bias_node = None
                    for node in p.params:
                        weight_or_bias = getattr(gm, node.target)  # type: ignore[arg-type]
                        if weight_or_bias.ndim == 2:  # type: ignore[attr-defined]
                            weight_node = node
                        if weight_or_bias.ndim == 1:  # type: ignore[attr-defined]
                            bias_node = node
                    if weight_node is None:
                        raise ValueError("No weight found in Linear pattern")
                    # find use of act node within the matched pattern
                    act_use_node = None
                    for node in p.nodes:
                        if node in act_node.users:  # type: ignore[union-attr]
                            act_use_node = node
                            break
                    if act_use_node is None:
                        raise ValueError(
                            "Could not find a user of act node within matched pattern."
                        )
                    if _is_annotated([act_use_node]) is False:  # type: ignore[list-item]
                        _annotate_input_qspec_map(
                            act_use_node,
                            act_node,
                            act_qspec,
                        )
                    if bias_node and _is_annotated([bias_node]) is False:
                        _annotate_output_qspec(bias_node, bias_qspec)
                    if _is_annotated([weight_node]) is False:  # type: ignore[list-item]
                        _annotate_output_qspec(weight_node, weight_qspec)
                    if _is_annotated([output_node]) is False:
                        _annotate_output_qspec(output_node, act_qspec)
                    nodes_to_mark_annotated = list(p.nodes)
                    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    def _annotate_maxpool2d(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
    ) -> None:
        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d]
        )
        maxpool_partitions = list(itertools.chain(*module_partitions.values()))
        for maxpool_partition in maxpool_partitions:
            output_node = maxpool_partition.output_nodes[0]
            maxpool_node = None
            for n in maxpool_partition.nodes:
                if n.target == torch.ops.aten.max_pool2d_with_indices.default:
                    maxpool_node = n
            if _is_annotated([output_node, maxpool_node]):  # type: ignore[list-item]
                continue
            input_act = maxpool_node.args[0]  # type: ignore[union-attr]
            assert isinstance(input_act, Node)
            act_qspec = get_input_act_qspec(quantization_config)
            maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(  # type: ignore[union-attr]
                input_qspec_map={
                    input_act: act_qspec,
                },
                _annotated=True,
            )
            output_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=SharedQuantizationSpec((input_act, maxpool_node)),
                _annotated=True,
            )

    def validate(self, model: torch.fx.GraphModule) -> None:
        """validate if the annotated graph is supported by the backend"""
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


def get_symmetric_quantization_config():
    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
        HistogramObserver
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
    )
    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float,
        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
    )
    return quantization_config


if __name__ == "__main__":
    example_inputs = (torch.randn(1, 3, 224, 224),)
    m = torchvision.models.resnet18().eval()
    m_copy = copy.deepcopy(m)
    # program capture
    m, guards = torchdynamo.export(
        m,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
    )
    quantizer = BackendQuantizer()
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)
    # Note: ``prepare_pt2e_quantizer`` will be updated to ``prepare_pt2e`` soon
    m = prepare_pt2e_quantizer(m, quantizer)
    # run calibration data through the prepared model to populate the observers
    after_prepare_result = m(*example_inputs)
    m = convert_pt2e(m)
    print("converted module is: {}".format(m), flush=True)
@aschabana

ModuleNotFoundError: No module named 'torch.ao.quantization._pt2e' :(

@leslie-fang-intel
Author

The import path may need some updates... Will take a look when I get time.

@aschabana

The import path may need some updates... Will take a look when I get time.

Thank you for the quick response, will wait for an update :)

@Bam17

Bam17 commented Mar 20, 2024

ModuleNotFoundError: No module named 'torch.ao.quantization._pt2e' :(

@aschabana
Forked the project and fixed the modules. You can access them here: https://gist.github.com/Bam17/ae66efffc85dec19e755ba2311fffefe

@aschabana

Thank you @Bam17 & @leslie-fang-intel :)

@singh20anurag

singh20anurag commented May 7, 2024

I encountered this issue when running this script with PyTorch 2.3.0+cu121:

Traceback (most recent call last):
  File "quantizing_resnet.py", line 258, in <module>
    m, guards = torchdynamo.export(
  File "/mypath/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 1443, in export
    return inner(*extra_args, **extra_kwargs)
  File "/mypath/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 1325, in inner
    dim_constraints.solve()
  File "/mypath/anaconda3/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1512, in solve
    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[2]"

@leslie-fang-intel
Author

Hi @singh20anurag, please refer to this tutorial: https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html?highlight=x86. For PyTorch 2.3, we should use capture_pre_autograd_graph instead of torchdynamo.export to export the graph.

    exported_model = capture_pre_autograd_graph(
        model,
        example_inputs
    )
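
For reference, a fuller sketch of the 2.3-era flow from that tutorial; the module paths used here (torch._export, torch.ao.quantization.quantize_pt2e, and the x86_inductor_quantizer module) are assumed from that tutorial and may differ on other releases:

    import torch
    import torchvision
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = torchvision.models.resnet18().eval()

    # capture_pre_autograd_graph replaces torchdynamo.export(..., aten_graph=True)
    exported_model = capture_pre_autograd_graph(model, example_inputs)

    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    prepared_model = prepare_pt2e(exported_model, quantizer)
    prepared_model(*example_inputs)  # calibration
    converted_model = convert_pt2e(prepared_model)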

@singh20anurag

Thanks! Was able to get past the issue. Wondering if modifying the Quantizer class is the only way for ExecuTorch to ingest quantized models or if it can also ingest models quantized with the FakeQuant approach?

@leslie-fang-intel
Author

Thanks! Was able to get past the issue. Wondering if modifying the Quantizer class is the only way for ExecuTorch to ingest quantized models or if it can also ingest models quantized with the FakeQuant approach?

Hi @singh20anurag, thanks for the question. TBH, I am not familiar with the ExecuTorch approach. Pinging @jerryzh168 @kimishpatel, who may have a more precise answer.
