Created October 31, 2022 21:38
Save vkuzo/21c1ae37a262696faa5914843187426a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.fx | |
import torch.nn.functional as F | |
import torch.ao.quantization.quantize_fx as quantize_fx | |
from torch.quantization.quantize_fx import prepare_fx, convert_fx | |
# Workaround for https://discuss.pytorch.org/t/dequantize-index-1-exceeded-reference-nodes-arg-length-1/164445
class M(torch.nn.Module):
    """Minimal repro module for an FX graph mode quantization failure.

    Its forward pass calls ``F.linear`` with every tensor passed by keyword
    (``input=`` / ``weight=``) and no bias.
    """

    def __init__(self):
        super().__init__()
        # Single randomly initialized 1x1 weight matrix.
        self.w = torch.nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        # The keyword-only call form is intentional; it is the pattern that
        # normalize_conv_linear rewrites into positional arguments.
        return F.linear(input=x, weight=self.w)
# Without normalization, quantizing M directly fails in convert_fx:
# m = M()
# mp = prepare_fx(m, torch.ao.quantization.get_default_qconfig_mapping('fbgemm'), (torch.randn(1, 1),))
# mq = convert_fx(mp)
def normalize_conv_linear(m, *, verbose=False):
    """Trace ``m`` and make the tensor arguments of linear/conv calls positional.

    For every ``F.linear`` / ``F.conv1d`` / ``F.conv2d`` / ``F.conv3d`` call in
    the symbolically traced graph, whichever of ``input``, ``weight`` and
    ``bias`` were passed by keyword are moved into ``args`` (in that order),
    as FX graph mode quantization expects them positionally. All other
    keyword arguments (``stride``, ``padding``, ...) are left as kwargs.

    Unlike going through ``Node.normalized_arguments``, this does not need an
    unambiguous operator schema, so overloaded ops such as ``conv2d`` (which
    also has a string-``padding`` overload) are handled too.

    Args:
        m: the ``torch.nn.Module`` to trace.
        verbose: if True, print each matching node before and after rewriting.

    Returns:
        A ``torch.fx.GraphModule`` whose graph has been rewritten and
        recompiled.
    """
    # linear and conv{1,2,3}d all start with (input, weight, bias, ...).
    target_ops = (F.linear, F.conv1d, F.conv2d, F.conv3d)
    leading_params = ('input', 'weight', 'bias')

    mt = torch.fx.symbolic_trace(m)
    for n in mt.graph.nodes:
        if n.op != 'call_function' or n.target not in target_ops:
            continue
        if verbose:
            print(n.format_node())

        new_args = list(n.args)
        new_kwargs = dict(n.kwargs)
        for position, name in enumerate(leading_params):
            if position < len(new_args):
                # Already passed positionally.
                continue
            if name not in new_kwargs:
                # Gap (e.g. bias omitted): later parameters cannot be made
                # positional without changing argument order, so stop here.
                break
            new_args.append(new_kwargs.pop(name))

        n.args = tuple(new_args)
        n.kwargs = new_kwargs
        if verbose:
            print(n.format_node())

    mt.graph.lint()
    mt.recompile()
    return mt
if __name__ == '__main__':
    # Post-training static quantization is meant to run on an eval-mode model.
    m = M().eval()
    mn = normalize_conv_linear(m)
    example_inputs = (torch.randn(1, 1),)
    mp = prepare_fx(
        mn, torch.ao.quantization.get_default_qconfig_mapping('fbgemm'),
        example_inputs)
    # Calibrate: feed data through the inserted observers so they record
    # activation ranges before convert_fx derives quantization parameters.
    mp(*example_inputs)
    mq = convert_fx(mp)
    print(mq)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.