Quantize Resnet50 model with TensorRT

Intro

TensorRT supports two approaches to preparing a model for int8 quantization: post-training calibration and quantization-aware training.

First we need to replace the model's regular torch.nn layers with TRT pytorch_quantization.nn layers. These quantization layers gather the statistics required for quantization; a minimal sketch of the replacement is shown below.
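
For example, pytorch_quantization ships drop-in quantized counterparts of common torch.nn layers. A minimal illustrative sketch (the layer parameters here are arbitrary):

from pytorch_quantization import nn as quant_nn

# a regular torch.nn.Conv2d(3, 64, kernel_size=7, stride=2) becomes:
conv = quant_nn.QuantConv2d(3, 64, kernel_size=7, stride=2)
# QuantConv2d carries TensorQuantizer modules for its input and weight,
# which collect the statistics used to pick the int8 scaling factors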

Once the model is modified, we can use one of the following approaches to gather statistics before quantization:

  1. Calibrate the pre-trained model
  2. Fine-tune (train for 1 epoch) the pre-trained model

The resulting model is then exported to ONNX.

The ONNX model can be converted to an int8 TRT engine.

Start NVIDIA PyTorch/TensorRT container

docker run -ti -v ~/workspace:/root/workspace \
  --gpus all --name py2301 \
  --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
  nvcr.io/nvidia/pytorch:23.01-py3

Install additional packages

pip3 install -U pip
pip3 install pycuda torchinfo
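
A quick optional sanity check that everything this walkthrough relies on is importable (pytorch_quantization and TensorRT come preinstalled in the NGC PyTorch image):

import tensorrt, pycuda, torchinfo, pytorch_quantization
print("TensorRT", tensorrt.__version__)  # 8.5.x in the 23.01 container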

Get test image

curl -L -o cat.jpg https://i.ibb.co/tXK0D91/Screen-Shot-2023-02-07-at-12-11-08-PM.jpg

Add torchvision/references/classification to PYTHONPATH

cd ~/workspace
git clone https://github.com/pytorch/vision.git torchvision
cd torchvision/references/classification
export PYTHONPATH=$PWD
cd ~/workspace
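
Optionally confirm the reference scripts resolve via PYTHONPATH (train.py and utils.py live in torchvision/references/classification):

import train, utils  # should import cleanly; if not, re-check PYTHONPATH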

Get ImageNet 1000 val and train datasets

cd ~/workspace && mkdir datasets && cd datasets
# Train dataset - 138 GB
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate
# Validation dataset - 6.3 GB
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate

# Alternatively, we can use academictorrents.com to get the datasets via torrent
# Train
wget https://academictorrents.com/download/a306397ccf9c2ead27155983c254227c0fd938e2.torrent
transmission-cli a306397ccf9c2ead27155983c254227c0fd938e2.torrent
mv Downloads/ILSVRC2012_img_train.tar .
# Validation
wget https://academictorrents.com/download/5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5.torrent
transmission-cli 5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5.torrent
mv Downloads/ILSVRC2012_img_val.tar .

# To extract and prepare Train Dataset
mkdir -p imagenet/train
cd imagenet/train
tar -xvf ../../ILSVRC2012_img_train.tar
# At this stage imagenet/train will contain 1000 compressed .tar files, one for each category
# Let's extract them into corresponding folders
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ~/workspace/datasets

# To extract and prepare Validation Dataset
mkdir -p imagenet/val
cd imagenet/val
tar -xvf ../../ILSVRC2012_img_val.tar
# Get the valprep.sh script from soumith's imagenetloader.torch repo and run it;
# it creates all class directories and moves the images into them
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
cd ~/workspace
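
Optionally verify the layout datasets.ImageFolder expects, i.e. 1000 class directories per split:

import os
root = os.path.expanduser("~/workspace/datasets/imagenet")
print(len(os.listdir(os.path.join(root, "train"))))  # 1000
print(len(os.listdir(os.path.join(root, "val"))))    # 1000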

Prepare Training and Validation Dataset

from pathlib import Path
import os
import torch
from torchvision import datasets
from torchvision import transforms

dataset_dir = os.path.join(Path.home(), "workspace/datasets/imagenet")
batch_size = 64
kwargs = {"num_workers": 4, "pin_memory": True}
norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training
train_dir = os.path.join(dataset_dir, "train")
train_trans = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    norm,
]
train_data = datasets.ImageFolder(train_dir, transform=transforms.Compose(train_trans))
# Prepare a smaller train dataset (10% of the original size) to speed up the process
train_data_small, _ = torch.utils.data.random_split(
    train_data, [0.1, 0.9], generator=torch.Generator().manual_seed(42)
)
train_data_loader = torch.utils.data.DataLoader(
    train_data_small, batch_size=batch_size, shuffle=True, **kwargs
)

# Validation
val_dir = os.path.join(dataset_dir, "val")
val_trans = [
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    norm,
]
val_data = datasets.ImageFolder(val_dir, transform=transforms.Compose(val_trans))
val_data_loader = torch.utils.data.DataLoader(
    val_data, batch_size=batch_size, shuffle=True, **kwargs
)
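
Optional sanity check: pull one batch and confirm the tensor shapes match the transforms above:

images, labels = next(iter(val_data_loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64])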

Prepare the model

Prepare a special ResNet50 model in which the normal torch.nn layers are automatically replaced with TRT pytorch_quantization.nn layers (e.g. QuantConv2d).

import torch
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from torchinfo import summary

from tqdm import tqdm

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

# Set default QuantDescriptor to use histogram based calibration for activation
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

# Now new models will automatically have QuantConv2d layers instead of regular Conv2d
from pytorch_quantization import quant_modules
quant_modules.initialize() 

img = read_image("cat.jpg")

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
batch = preprocess(img).unsqueeze(0)

model = resnet50(weights=weights)
model = model.eval()

summary(model, batch.shape)  # make sure the model consists of QuantConv2d layers

model = model.cuda()
batch = batch.cuda()

prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

Test Accuracy

from train import evaluate
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)

Calibrate the model

Switch the model to calibration mode and feed data into it

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    progress_bar = tqdm(data_loader, total=num_batches, desc='Calibrate')
    for i, (data, target) in enumerate(progress_bar):
        if i >= num_batches:
            break
        model(data.cuda())
    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    """Load calibration results and compute amax (dynamic range) for each quantizer"""
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
                #print(F"{name:40}: {module}")
    model.cuda()  # the newly created amax tensors need to be moved to the GPU

with torch.no_grad():
    collect_stats(model, train_data_loader, num_batches=20)

We can try different calibration methods and see which one works best.

with torch.no_grad():
    print("percentile 99.99 calibration")
    compute_amax(model, method="percentile", percentile=99.99)
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)
with torch.no_grad():
    print("percentile 99.9 calibration")
    compute_amax(model, method="percentile", percentile=99.9)
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)
with torch.no_grad():
    method="entropy"
    print(F"{method} calibration")
    compute_amax(model, method=method)
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)
with torch.no_grad():
    method="mse"
    print(F"{method} calibration")
    compute_amax(model, method=method)
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)

Save weights

# Save calibrated model state dictionary
torch.save(model.state_dict(), "quant_resnet50-calibrated.pth")

Quantization Aware Training

We can fine-tune the calibrated model to improve accuracy further.

Load the calibrated model (only needed when starting from a fresh session; quant_modules.initialize() must have been called beforehand so that resnet50() is built with quantized layers)

import torch
from torchvision.models import resnet50

model = resnet50()  # contains QuantConv2d layers because quant_modules.initialize() was called
model.load_state_dict(torch.load("quant_resnet50-calibrated.pth"))
model = model.cuda()

Train

from train import train_one_epoch

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)  # only relevant when fine-tuning for more than one epoch

class args:
  print_freq=10
  clip_grad_norm=None
  model_ema_steps=32
  lr_warmup_epochs=0

train_one_epoch(model, criterion, optimizer, train_data_loader, "cuda", 0, args)

# Save the model
torch.save(model.state_dict(), "quant_resnet50-finetuned.pth")
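
Optionally re-check accuracy after the QAT epoch, reusing evaluate and val_data_loader from above:

with torch.no_grad():
    evaluate(model, criterion, val_data_loader, device="cuda", print_freq=20)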

Export to ONNX

from pytorch_quantization import nn as quant_nn

quant_nn.TensorQuantizer.use_fb_fake_quant = True

dummy_input = torch.randn(1, 3, 224, 224, device="cuda")
input_names = ["input0"]
output_names = ["output0"]
dynamic_axes = {"input0": {0: "batch"}, "output0": {0: "batch"}}

# Sets the model to inference mode - train(False)
model = model.eval()
y = model(dummy_input)
torch.onnx.export(
    model,
    dummy_input,
    "quant_resnet50.onnx",
    verbose=True,
    opset_version=13,  # Q/DQ (fake-quant) nodes require opset 13 or newer
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)
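
A quick optional check of the exported graph; the quantization shows up as QuantizeLinear/DequantizeLinear node pairs (the onnx package ships with the NGC container):

import onnx

onnx_model = onnx.load("quant_resnet50.onnx")
onnx.checker.check_model(onnx_model)
print(sorted({n.op_type for n in onnx_model.graph.node
              if n.op_type in ("QuantizeLinear", "DequantizeLinear")}))
# expect: ['DequantizeLinear', 'QuantizeLinear']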

Build TRT Engine

trtexec \
--int8 \
--verbose \
--onnx=quant_resnet50.onnx \
--saveEngine=quant_resnet50.trt \
--minShapes=input0:1x3x224x224 \
--optShapes=input0:8x3x224x224 \
--maxShapes=input0:16x3x224x224
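
If you prefer the TensorRT Python API over trtexec, a roughly equivalent build looks like this (a minimal sketch; error handling mostly omitted):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("quant_resnet50.onnx", "rb") as f:
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)  # honor the Q/DQ nodes from the ONNX graph

profile = builder.create_optimization_profile()
profile.set_shape("input0", (1, 3, 224, 224), (8, 3, 224, 224), (16, 3, 224, 224))
config.add_optimization_profile(profile)

engine_bytes = builder.build_serialized_network(network, config)
with open("quant_resnet50.trt", "wb") as f:
    f.write(engine_bytes)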

Run TRT Engine using Python API

import numpy as np
from torchvision.io import read_image
from torchvision.models import ResNet50_Weights
img = read_image("cat.jpg")
preprocess = ResNet50_Weights.DEFAULT.transforms()
batch = preprocess(img).unsqueeze(0)
batch = batch.numpy()
batch = np.concatenate([batch]*8)

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes a CUDA context
trt_logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(trt_logger)
fpath="quant_resnet50.trt"

with open(fpath, "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
BATCH_SIZE = 8
context.set_input_shape("input0", (BATCH_SIZE, 3, 224, 224))

print("Engine Info:")
for i, binding in enumerate(engine):
    shape = [engine.max_batch_size, *engine.get_binding_shape(binding)]
    dtype = trt.nptype(engine.get_binding_dtype(binding))
    volume = abs(trt.volume(engine.get_binding_shape(binding)))
    if engine.binding_is_input(binding):
        desc = "input"
    else:
        desc = "output"
    print(f"{i} type:    {desc}\n  binding: {binding} \n  data:    {np.dtype(dtype).name}\n  shape:   {shape} => {volume} \n")

USE_FP16 = False
target_dtype = np.float16 if USE_FP16 else np.float32

output = np.empty([BATCH_SIZE, 1000], dtype = target_dtype)

# allocate device memory
d_input = cuda.mem_alloc(1 * batch.nbytes)
d_output = cuda.mem_alloc(1 * output.nbytes)
bindings = [int(d_input), int(d_output)]
stream = cuda.Stream()

def predict(batch): # result gets copied into output
    # transfer input data to device
    cuda.memcpy_htod_async(d_input, batch, stream)
    # execute model
    context.execute_async_v2(bindings, stream.handle, None)
    # transfer predictions back
    cuda.memcpy_dtoh_async(output, d_output, stream)
    # synchronize the stream (wait for the copies and inference to finish)
    stream.synchronize()

predict(batch)

best_ids=np.argmax(output,axis=-1)
print("Best class ids:", best_ids)

# Warmup
for i in range(100):
  predict(batch)

# Measure latency, normalized per image
import time
TT = []
for i in range(100):
  t0 = time.time()
  predict(batch)
  t1 = time.time()
  TT.append((t1 - t0) * 1000 / BATCH_SIZE)  # ms per image

print("AVG time per image (ms):", np.mean(TT))
print("P50 time per image (ms):", np.percentile(TT, 50))
print("P95 time per image (ms):", np.percentile(TT, 95))