Skip to content

Instantly share code, notes, and snippets.

@vkuzo
Created May 31, 2022 12:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vkuzo/a9dbaf4d2e69d8289aa7d21d6d7e4e3c to your computer and use it in GitHub Desktop.
class conv_with_conv(nn.Module):
    """Streaming convolution over a sliding window of recent input frames.

    The buffer ``internal_state`` holds ``kernel_size - stride + input_dim``
    rows ("frames"), each with ``in_ch`` values. Every ``forward`` call drops
    the oldest ``input_dim`` frames, appends the ``input_dim`` frames carried
    by ``x``, stores the result as the new state and runs ``self.conv`` over
    that whole window along the frame axis.

    Args:
        input_dim: number of new frames supplied per ``forward`` call.
        in_ch: values per frame (input channels of ``self.conv``).
        out_ch: output channels of ``self.conv``.
        kernel_size: conv kernel length along the frame axis.
        stride: conv stride along the frame axis.
    """

    def __init__(self, input_dim, in_ch, out_ch, kernel_size, stride):
        super().__init__()
        self.shift_size = input_dim
        scale, zero_point = 1e-4, 2
        dtype = torch.qint8
        float_state = torch.zeros(kernel_size - stride + input_dim, in_ch)
        # All-zero initial window. Its dtype/qparams are only a placeholder:
        # forward() re-expresses the state in the representation of the
        # incoming activation (float before convert, quantized after).
        int_state = torch.quantize_per_tensor(float_state, scale, zero_point, dtype)
        self.f_cat = nn.quantized.FloatFunctional()
        self.register_buffer('internal_state', int_state)
        self.conv = nn.Conv2d(in_ch, out_ch, (kernel_size, 1), stride)

    def _state_like(self, x):
        """Return ``internal_state`` in the same representation as ``x``.

        Float ``torch.cat`` cannot mix quantized and float tensors, and the
        quantized cat requires every input to share one quantized dtype, so
        the stored state is dequantized / (re)quantized to match ``x``.
        """
        state = self.internal_state
        if not x.is_quantized:
            return state.dequantize() if state.is_quantized else state
        if state.is_quantized and state.dtype == x.dtype:
            # Same dtype: quantized cat requantizes differing qparams itself.
            return state
        if state.is_quantized:
            state = state.dequantize()
        return torch.quantize_per_tensor(
            state, x.q_scale(), x.q_zero_point(), x.dtype)

    def forward(self, x):
        """Push the frames in ``x`` into the window and convolve the window.

        Args:
            x: new frames, expected as ``(1, input_dim, in_ch, 1)`` (float, or
                per-tensor quantized after ``convert``). The batch must be 1
                because ``internal_state`` holds a single stream.

        Returns:
            ``self.conv`` applied to the updated window laid out as
            ``(1, in_ch, kernel_size - stride + input_dim, 1)``.

        Raises:
            ValueError: if ``x`` does not carry exactly ``input_dim`` frames
                (the state length would otherwise drift between calls).
        """
        if x.size(1) != self.shift_size:
            raise ValueError("x must carry exactly input_dim frames along dim 1")
        state = self._state_like(x)
        # Keep the newest len(state) - shift_size frames, as (1, T, in_ch, 1).
        history = state[self.shift_size:].unsqueeze(0).unsqueeze(-1)
        # Append the new frames along the frame axis (dim=1, not batch dim 0).
        window = self.f_cat.cat([history, x], dim=1)
        # Persist the window as the new state; detach so autograd history
        # does not accumulate across streaming calls.
        self.internal_state = window.squeeze(-1).squeeze(0).detach()
        return self.conv(window.transpose(1, 2))
class test_model(nn.Module):
    """Toy model: QuantStub -> conv_with_conv -> DeQuantStub.

    The stubs mark where tensors enter and leave the quantized region in
    eager-mode quantization.
    """

    def __init__(self):
        super(test_model, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.conv = conv_with_conv(
            input_dim=2,
            in_ch=40,
            out_ch=1,
            kernel_size=3,
            stride=1,
        )

    def forward(self, x):
        quantized = self.quant(x)
        convolved = self.conv(quantized)
        return self.dequant(convolved)
# Build the model, apply eager-mode post-training static quantization (fbgemm)
# and run one quantized forward pass on CPU.
model = test_model()
model.eval()
model.to('cpu')
dummy_input = torch.rand(1, 2, 40, 1)
# prepare + convert is the post-training (PTQ) flow, so use the PTQ qconfig;
# get_default_qat_qconfig is intended for prepare_qat.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# Calibration: observers must see data before convert, otherwise convert
# falls back to default (meaningless) scale / zero_point values.
# NOTE: conv_with_conv is stateful, so this pass also advances its window.
with torch.no_grad():
    model(dummy_input)
torch.quantization.convert(model, inplace=True)
# Optional mobile export of the converted model (requires
# torch.utils.mobile_optimizer.optimize_for_mobile):
# jit_model = torch.jit.script(model)
# torchscript_model_optimized = optimize_for_mobile(jit_model)
# torchscript_model_optimized._save_for_lite_interpreter("test.ptl")
with torch.no_grad():
    out = model(dummy_input)
print('quant out', out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment