@TadaoYamaoka
Created May 27, 2023 03:23
Quantization Aware Training with FX Graph Mode
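
A self-contained script: it trains ResNet18 on CIFAR-10, fine-tunes it with FX graph mode quantization aware training (prepare_qat_fx / convert_fx), compares FP32 and INT8 accuracy and inference latency, and exports both models to ONNX.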
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
import torch.onnx
import time
import copy
import numpy as np
from torchvision.models import resnet18

def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform)
    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=train_batch_size,
        sampler=train_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=eval_batch_size,
        sampler=test_sampler, num_workers=num_workers)
    return train_loader, test_loader

def evaluate_model(model, test_loader, device, criterion=None):
    model.eval()
    model.to(device)
    running_loss = 0
    running_corrects = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0
            # statistics
            running_loss += loss * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)
    return eval_loss, eval_accuracy

def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):
    # The training configurations were not carefully selected.
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # SGD appears to work better than Adam for ResNet18 training on CIFAR-10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)
    # Evaluation before training
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))
    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)
        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
        # Step the learning rate scheduler
        scheduler.step()
        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))
    return model

def measure_inference_latency(model,
                              device,
                              input_size=(1, 3, 32, 32),
                              num_samples=100,
                              num_warmups=10):
    model.to(device)
    model.eval()
    x = torch.rand(size=input_size).to(device)
    with torch.no_grad():
        for _ in range(num_warmups):
            _ = model(x)
    if device.type == "cuda":
        # Wait for all CUDA kernels to finish so timings are accurate.
        torch.cuda.synchronize()
    with torch.no_grad():
        start_time = time.time()
        for _ in range(num_samples):
            _ = model(x)
        if device.type == "cuda":
            torch.cuda.synchronize()
        end_time = time.time()
    elapsed_time = end_time - start_time
    elapsed_time_ave = elapsed_time / num_samples
    return elapsed_time_ave

def save_model(model, model_dir, model_filename):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)

def load_model(model, model_filepath, device):
    model.load_state_dict(torch.load(model_filepath, map_location=device))
    return model

def save_torchscript_model(model, model_dir, model_filename):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)

def load_torchscript_model(model_filepath, device):
    model = torch.jit.load(model_filepath, map_location=device)
    return model

def model_equivalence(model_1, model_2, device, rtol=1e-05, atol=1e-08, num_tests=100, input_size=(1, 3, 32, 32)):
    model_1.to(device)
    model_2.to(device)
    for _ in range(num_tests):
        x = torch.rand(size=input_size).to(device)
        y1 = model_1(x).detach().cpu().numpy()
        y2 = model_2(x).detach().cpu().numpy()
        if not np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False):
            print("Model equivalence test sample failed: ")
            print(y1)
            print(y2)
            return False
    return True

def main():
    num_classes = 10
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")
    model_dir = "saved_models"
    model_filename = "resnet18_cifar10.pt"
    quantized_model_filename = "resnet18_quantized_cifar10.pt"
    model_filepath = os.path.join(model_dir, model_filename)
    quantized_model_filepath = os.path.join(model_dir, quantized_model_filename)
    # Create an untrained model.
    model = resnet18(num_classes=num_classes, weights=None)
    train_loader, test_loader = prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256)
    # Train model.
    print("Training Model...")
    model = train_model(model=model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-1, num_epochs=100)
    # Save model.
    save_model(model=model, model_dir=model_dir, model_filename=model_filename)
    # Prepare the model for quantization aware training.
    model.to(cpu_device)
    model_to_quantize = copy.deepcopy(model)
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
    example_inputs = torch.rand(size=(1, 3, 32, 32)).to(cpu_device)
    prepared_model = prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
    # Print FP32 model.
    print(model)
    # Print fused model.
    print(prepared_model)
    # Model and fused model should be equivalent.
    model.eval()
    prepared_model.eval()
    assert model_equivalence(model_1=model, model_2=prepared_model, device=cpu_device, rtol=1e-01, atol=3, num_tests=100, input_size=(1, 3, 32, 32)), "Fused model is not equivalent to the original model!"
    # Quantization aware training
    print("Training QAT Model...")
    prepared_model.train()
    train_model(model=prepared_model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-3, num_epochs=10)
    prepared_model.to(cpu_device)
    # Convert the trained model to a quantized model.
    quantized_model = convert_fx(prepared_model)
    quantized_model.eval()
    # Print quantized model.
    print(quantized_model)
    # Save quantized model.
    save_torchscript_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)
    # Load quantized model.
    quantized_jit_model = load_torchscript_model(model_filepath=quantized_model_filepath, device=cpu_device)
    _, fp32_eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=cpu_device, criterion=None)
    _, int8_eval_accuracy = evaluate_model(model=quantized_jit_model, test_loader=test_loader, device=cpu_device, criterion=None)
    print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy))
    print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))
    fp32_cpu_inference_latency = measure_inference_latency(model=model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    int8_jit_cpu_inference_latency = measure_inference_latency(model=quantized_jit_model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    fp32_gpu_inference_latency = measure_inference_latency(model=model, device=cuda_device, input_size=(1, 3, 32, 32), num_samples=100)
    print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000))
    print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(fp32_gpu_inference_latency * 1000))
    print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(int8_cpu_inference_latency * 1000))
    print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(int8_jit_cpu_inference_latency * 1000))
    # Export ONNX
    model.to(cpu_device)
    torch.onnx.export(model, example_inputs, os.path.splitext(model_filepath)[0] + ".onnx",
                      verbose=True,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'},
                      })
    torch.onnx.export(quantized_model, example_inputs, os.path.splitext(quantized_model_filepath)[0] + ".onnx",
                      verbose=True,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'},
                      })

if __name__ == "__main__":
    main()
@chamecall

nice work, thanks man

@saeedadeeb103

saeedadeeb103 commented Oct 22, 2023

I have a question about using this with MobileNet: I always get a significant drop in accuracy compared to ResNet. I know it's because of the depthwise separable layers. My question is: can we skip those layers during QAT?
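
For the FX workflow above, one way to leave specific layers in floating point is to clear their qconfig in the QConfigMapping before calling prepare_qat_fx: modules mapped to None are simply not quantized. A minimal sketch, assuming a torchvision MobileNetV2; the module name "features.4" is a hypothetical placeholder, so print the model to find the blocks you actually want to skip:

import torch
from torchvision.models import mobilenet_v2
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

model = mobilenet_v2(num_classes=10)
model.train()  # QAT preparation expects a model in train mode
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
# Map the modules that should stay in FP32 to qconfig=None so that
# prepare_qat_fx leaves them untouched. "features.4" is a placeholder;
# substitute the depthwise blocks you actually want to skip.
qconfig_mapping.set_module_name("features.4", None)
example_inputs = torch.rand(1, 3, 32, 32)
prepared_model = prepare_qat_fx(model, qconfig_mapping, example_inputs)

Whether skipping those layers recovers the lost accuracy is something you would have to verify empirically.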
