# Gist by @vkuzo, created December 28, 2020.
# Disabling fake quantization in eager-mode QAT, including for FakeQuantize subclasses.

import torch
import torch.nn as nn

# Toy single-layer "LeNet" operating on flattened 28x28 MNIST images.
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu1(self.l1(x.view(x.size(0), -1)))

# Fake-quantize factories for weights and activations: 8-bit unsigned affine
# quantization with moving-average min/max observers.
fq_weight = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)
fq_activation = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)

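# Hedged aside: with_args returns a factory (a functools.partial-style
# wrapper) rather than a module instance; prepare_qat calls it once per
# attachment point. A quick check under that assumption:
assert isinstance(fq_weight(), torch.quantization.FakeQuantize)
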
# Attach a QConfig only to l1, swap it for its QAT counterpart, then turn off
# fake quant (observers keep running; simulated quantization does not).
model = LeNet()
model.l1.qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weight)
torch.quantization.prepare_qat(model, inplace=True)
model.l1.apply(torch.quantization.disable_fake_quant)

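# Sanity check (assumes the eager-mode attribute/buffer names: after
# prepare_qat, model.l1 is an nn.qat.Linear holding .weight_fake_quant for the
# weight and .activation_post_process for the output; disable_fake_quant
# zeroes their fake_quant_enabled buffers).
print(model.l1.weight_fake_quant.fake_quant_enabled)        # expect tensor([0], dtype=torch.uint8)
print(model.l1.activation_post_process.fake_quant_enabled)  # expect tensor([0], dtype=torch.uint8)
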
# FakeQuantize subclass with an extra constructor argument. Stock helpers that
# match the FakeQuantize type exactly will skip instances of this class, which
# is what motivates custom_disable_fake_quant below.
class MyFakeQuantize(torch.quantization.FakeQuantize):
    def __init__(self, observer, quant_min, quant_max, n_cluster=0, **observer_kwargs):
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        # store the custom argument so it is not silently dropped
        self.n_cluster = n_cluster

# Rebuild the fake-quant factories on top of the subclass.
fq_weight = MyFakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)
fq_activation = MyFakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)

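# Why the custom helper below is needed (hedged sketch): in the torch build
# this gist targets, torch.quantization.disable_fake_quant matches the
# FakeQuantize type exactly, so a subclass instance is silently skipped.
fq = MyFakeQuantize(torch.quantization.MovingAverageMinMaxObserver,
                    quant_min=0, quant_max=255, dtype=torch.quint8)
torch.quantization.disable_fake_quant(fq)
print(fq.fake_quant_enabled)  # stays tensor([1], ...) if matched by exact type
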
# Copied from torch/quantization/fake_quantize.py for reference.
import re

def _is_fake_quant_script_module(mod):
    ''' Returns true if given mod is an instance of FakeQuantize script module.
    '''
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like
        # '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split('.', 1)[1]
        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
        return name == 'torch.quantization.fake_quantize.FakeQuantize'
    return False

# isinstance-based variant of disable_fake_quant: matches FakeQuantize
# subclasses (such as MyFakeQuantize) as well as scripted instances.
def custom_disable_fake_quant(mod):
    if isinstance(mod, torch.quantization.fake_quantize.FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()

# Repeat the flow with the subclass; the custom helper now disables fake quant.
model2 = LeNet()
model2.l1.qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weight)
torch.quantization.prepare_qat(model2, inplace=True)
model2.l1.apply(custom_disable_fake_quant)
print(model2)
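
# Optional smoke test (assumes an MNIST-shaped input): the forward pass runs
# with observers active while fake quant stays disabled.
x = torch.randn(2, 1, 28, 28)
print(model2(x).shape)  # torch.Size([2, 10])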