Skip to content

Instantly share code, notes, and snippets.

@vkuzo
Created October 31, 2022 21:38
Show Gist options
  • Save vkuzo/21c1ae37a262696faa5914843187426a to your computer and use it in GitHub Desktop.
Save vkuzo/21c1ae37a262696faa5914843187426a to your computer and use it in GitHub Desktop.
import torch
import torch.fx
import torch.nn.functional as F
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.quantization.quantize_fx import prepare_fx, convert_fx
# workaround for https://discuss.pytorch.org/t/dequantize-index-1-exceeded-reference-nodes-arg-length-1/164445
class M(torch.nn.Module):
    """Minimal module whose forward pass invokes ``F.linear`` with
    keyword-only arguments.

    The keyword-only call is deliberate: it reproduces the FX graph-mode
    quantization failure that ``normalize_conv_linear`` works around.
    """

    def __init__(self):
        super().__init__()
        # A 1x1 weight, so the module behaves like scalar multiplication.
        self.w = torch.nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        # Keyword arguments on purpose — this is what trips up prepare_fx.
        return F.linear(input=x, weight=self.w)
# this will fail
# m = M()
# mp = prepare_fx(m, torch.ao.quantization.get_default_qconfig_mapping('fbgemm'), (torch.randn(1, 1),))
# mq = convert_fx(mp)
def normalize_conv_linear(m):
mt = torch.fx.symbolic_trace(m)
for n in mt.graph.nodes:
if n.op == 'call_function' and n.target in (F.linear, F.conv1d, F.conv2d, F.conv3d):
print(n.format_node())
norm_args, norm_kwargs = n.normalized_arguments(mt)
print(norm_args, norm_kwargs)
new_args_list = list(n.args)
new_kwargs = norm_kwargs
if len(new_args_list) == 0:
new_args_list.append(new_kwargs['input'])
del new_kwargs['input']
if len(new_args_list) == 1:
new_args_list.append(new_kwargs['weight'])
del new_kwargs['weight']
if len(new_args_list) == 2 and 'bias' in new_kwargs:
new_args_list.append(new_kwargs['bias'])
del new_kwargs['bias']
n.args = tuple(new_args_list)
n.kwargs = new_kwargs
mt.recompile()
mt.graph.lint()
return mt
# Demonstration: with the keyword arguments promoted to positional,
# FX graph-mode quantization succeeds on the same module.
m = M()
# Rewrite the F.linear call so input/weight are passed positionally.
mn = normalize_conv_linear(m)
# prepare_fx needs an example-input tuple and a qconfig mapping;
# 'fbgemm' is the x86 server quantization backend.
mp = prepare_fx(
mn, torch.ao.quantization.get_default_qconfig_mapping('fbgemm'),
(torch.randn(1, 1),))
mq = convert_fx(mp)
print(mq)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment