Quantization Aware Training with FX Graph Mode
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
import torch.onnx
import time
import copy
import numpy as np
from torchvision.models import resnet18

def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform)

    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=train_batch_size,
        sampler=train_sampler, num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=eval_batch_size,
        sampler=test_sampler, num_workers=num_workers)

    return train_loader, test_loader

def evaluate_model(model, test_loader, device, criterion=None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy

def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):

    # The training configurations were not carefully selected.
    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that the SGD optimizer is better than the Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)

    # Evaluation before training.
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

    for epoch in range(num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)

        # Step the learning rate scheduler.
        scheduler.step()

        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

    return model

def measure_inference_latency(model,
                              device,
                              input_size=(1, 3, 32, 32),
                              num_samples=100,
                              num_warmups=10):

    model.to(device)
    model.eval()

    x = torch.rand(size=input_size).to(device)

    # Warm-up runs so that one-time initialization does not skew the timing.
    with torch.no_grad():
        for _ in range(num_warmups):
            _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()

    with torch.no_grad():
        start_time = time.time()
        for _ in range(num_samples):
            _ = model(x)
            if device.type == "cuda":
                torch.cuda.synchronize()
        end_time = time.time()
    elapsed_time = end_time - start_time
    elapsed_time_ave = elapsed_time / num_samples

    return elapsed_time_ave

def save_model(model, model_dir, model_filename):

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)


def load_model(model, model_filepath, device):

    model.load_state_dict(torch.load(model_filepath, map_location=device))
    return model

def save_torchscript_model(model, model_dir, model_filename):

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)


def load_torchscript_model(model_filepath, device):

    model = torch.jit.load(model_filepath, map_location=device)
    return model

def model_equivalence(model_1, model_2, device, rtol=1e-05, atol=1e-08, num_tests=100, input_size=(1, 3, 32, 32)):

    model_1.to(device)
    model_2.to(device)

    for _ in range(num_tests):
        x = torch.rand(size=input_size).to(device)
        y1 = model_1(x).detach().cpu().numpy()
        y2 = model_2(x).detach().cpu().numpy()
        if not np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False):
            print("Model equivalence test sample failed: ")
            print(y1)
            print(y2)
            return False

    return True

def main():

    num_classes = 10
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")

    model_dir = "saved_models"
    model_filename = "resnet18_cifar10.pt"
    quantized_model_filename = "resnet18_quantized_cifar10.pt"
    model_filepath = os.path.join(model_dir, model_filename)
    quantized_model_filepath = os.path.join(model_dir, quantized_model_filename)

    # Create an untrained model.
    model = resnet18(num_classes=num_classes, pretrained=False)

    train_loader, test_loader = prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256)

    # Train model.
    print("Training Model...")
    model = train_model(model=model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-1, num_epochs=100)
    # Save model.
    save_model(model=model, model_dir=model_dir, model_filename=model_filename)

    # Prepare a model for quantization aware training.
    model.to(cpu_device)
    model_to_quantize = copy.deepcopy(model)
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
    example_inputs = torch.rand(size=(1, 3, 32, 32)).to(cpu_device)
    prepared_model = prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)

    # Print FP32 model.
    print(model)
    # Print fused model.
    print(prepared_model)

    # Model and fused model should be equivalent.
    model.eval()
    prepared_model.eval()
    assert model_equivalence(model_1=model, model_2=prepared_model, device=cpu_device, rtol=1e-01, atol=3, num_tests=100, input_size=(1, 3, 32, 32)), "Fused model is not equivalent to the original model!"

    # Quantization aware training.
    print("Training QAT Model...")
    prepared_model.train()
    train_model(model=prepared_model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-3, num_epochs=10)
    prepared_model.to(cpu_device)

    # Convert the trained model to a quantized model.
    quantized_model = convert_fx(prepared_model)
    quantized_model.eval()

    # Print quantized model.
    print(quantized_model)

    # Save quantized model.
    save_torchscript_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)

    # Load quantized model.
    quantized_jit_model = load_torchscript_model(model_filepath=quantized_model_filepath, device=cpu_device)

    _, fp32_eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=cpu_device, criterion=None)
    _, int8_eval_accuracy = evaluate_model(model=quantized_jit_model, test_loader=test_loader, device=cpu_device, criterion=None)

    print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy))
    print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))

    fp32_cpu_inference_latency = measure_inference_latency(model=model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    int8_jit_cpu_inference_latency = measure_inference_latency(model=quantized_jit_model, device=cpu_device, input_size=(1, 3, 32, 32), num_samples=100)
    fp32_gpu_inference_latency = measure_inference_latency(model=model, device=cuda_device, input_size=(1, 3, 32, 32), num_samples=100)

    print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000))
    print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(fp32_gpu_inference_latency * 1000))
    print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(int8_cpu_inference_latency * 1000))
    print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(int8_jit_cpu_inference_latency * 1000))

    # Export ONNX.
    model.to(cpu_device)
    torch.onnx.export(model, example_inputs, os.path.splitext(model_filepath)[0] + ".onnx",
                      verbose=True,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'},
                      })
    torch.onnx.export(quantized_model, example_inputs, os.path.splitext(quantized_model_filepath)[0] + ".onnx",
                      verbose=True,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'},
                      })


if __name__ == "__main__":
    main()
I have a question about using this for MobileNet: I always get a significant loss in accuracy compared to ResNet. I know it's because of the depthwise separable layers. My question is: can we skip those layers during QAT?
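One possible approach (a sketch, not part of the gist; the module names and patterns below are hypothetical placeholders) is to assign a qconfig of None to the layers you want to keep in FP32 via the QConfigMapping before calling prepare_qat_fx:

# Sketch: exclude selected layers from quantization by giving them a None qconfig.
# Replace the hypothetical names/patterns with the actual ones reported by
# model_to_quantize.named_modules() for your MobileNet.
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
qconfig_mapping = qconfig_mapping.set_module_name("features.3.conv.0", None)              # one specific module
qconfig_mapping = qconfig_mapping.set_module_name_regex(r"features\..*\.conv\.0", None)   # or a name pattern
prepared_model = prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)

Modules with a None qconfig are left unquantized, so the rest of the network still goes through QAT while the sensitive layers stay in FP32.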
@TadaoYamaoka Nice work!
Can you please elaborate on why the equivalence criterion is so lenient?
"rtol=1e-01, atol=3"
There's no deep reason; you can make your experiments more stringent.
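For a stricter check, one option (a sketch, not part of the gist; the tolerance values are illustrative) is to temporarily disable the fake-quantize and observer modules that prepare_qat_fx inserts, since they perturb the prepared model's outputs, run model_equivalence with tighter tolerances, and then re-enable them before QAT:

# Sketch: compare against the FP32 model with fake quantization switched off,
# which should allow much tighter tolerances, then re-enable it before QAT.
from torch.ao.quantization.fake_quantize import (
    disable_fake_quant, disable_observer, enable_fake_quant, enable_observer)

prepared_model.apply(disable_fake_quant)
prepared_model.apply(disable_observer)
assert model_equivalence(model_1=model, model_2=prepared_model, device=cpu_device,
                         rtol=1e-03, atol=1e-05, num_tests=100, input_size=(1, 3, 32, 32))
prepared_model.apply(enable_observer)
prepared_model.apply(enable_fake_quant)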
nice work, thanks man