Skip to content

Instantly share code, notes, and snippets.

@vkuzo
Created April 21, 2020 23:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vkuzo/78b06c01f23f98ee2aaaeb37e55f8d40 to your computer and use it in GitHub Desktop.
Save vkuzo/78b06c01f23f98ee2aaaeb37e55f8d40 to your computer and use it in GitHub Desktop.
import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Small conv net wrapped in quant/dequant stubs. For a 28x28 input the
# stride-2 conv yields 14x14 maps, which AvgPool2d(14) collapses to 1x1,
# so the network emits two sigmoid outputs per sample.
model = nn.Sequential(
    QuantStub(),
    nn.Conv2d(3, 1, 1, bias=False),
    nn.BatchNorm2d(1),
    nn.ReLU(),
    nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2),
    nn.AvgPool2d(14),
    nn.Sigmoid(),
    DeQuantStub(),
)

# Fuse Conv+BN+ReLU (children 1-3) and Conv+BN (children 4-5) in place.
# Recent PyTorch releases restrict fuse_modules to eval-mode models and
# provide fuse_modules_qat for train-mode (QAT) fusion; releases that
# predate it do the train-mode fusion in fuse_modules itself.
try:
    from torch.ao.quantization import fuse_modules_qat as _fuse_for_qat
except ImportError:
    _fuse_for_qat = torch.quantization.fuse_modules
_fuse_for_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)

# Attach the default fbgemm QAT config and insert fake-quant modules.
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

optimizer = optim.Adam(model.parameters(), lr=1)

# Replicate over at most two CUDA devices. Capping at the number actually
# present keeps the script usable on single-GPU hosts; an empty list is
# turned into None so DataParallel falls back to its own default.
_replica_ids = list(range(min(2, torch.cuda.device_count()))) or None
model = nn.DataParallel(model, device_ids=_replica_ids)
model.to(device)
print(model)

criterion = nn.BCELoss()
# Debug toggle: uncomment to switch fake quantization off in every module.
#model.apply(torch.quantization.disable_fake_quant)
# QAT training loop: one optimisation step per epoch on random data, then
# convert a CPU copy of the model to a quantized model and run it once.
#
# NOTE(review): indentation was lost in the pasted source. The conversion
# and inference steps are assumed to run once per epoch; dedent them if
# they were meant to run a single time after training.

# The targets never change, so build them and move them to the device once.
labels = torch.tensor([[1, 1], [0, 0]], dtype=torch.float32).to(device)

for epoch in range(10):
    print('EPOCH', epoch)
    model.train()

    # Fresh random batch of two 3x28x28 inputs each epoch.
    inputs = torch.rand(2, 3, 28, 28).to(device)
    outputs = model(inputs)
    loss = criterion(outputs.view(2, 2), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # After the step of epoch 2 and every later epoch, disable the observers.
    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)
    # From epoch 3 on, also freeze the batch-norm statistics.
    if epoch >= 3:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    print('MODEL', model)

    # Convert a private eval-mode CPU copy of the wrapped module, leaving the
    # training model untouched. The copy is already ours, so convert in place
    # rather than letting convert() duplicate it a second time.
    quant_model = torch.quantization.convert(
        deepcopy(model.module).eval().cpu(), inplace=True
    )
    with torch.no_grad():
        out = quant_model(torch.rand(1, 3, 28, 28))
        print(out.view(2).tolist())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment