import sys
import os
import traceback
import torch
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
import torch.backends._nnapi.prepare
import torchvision.models.quantization.mobilenet
from pathlib import Path
# Bundle sample inputs with the models for easier benchmarking.
# This step is optional.
class BundleWrapper(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, arg):
        return self.mod(arg)


def make_nnapi(model_name, quantize_mode):
    quantize_core, quantize_iface = {
        "none": (False, False),
        "core": (True, False),
        "full": (True, True),
    }[quantize_mode]

    model = getattr(torchvision.models.quantization, model_name)(pretrained=True, quantize=quantize_core)
    model.eval()

    # Fuse BatchNorm operators in the floating point model.
    # (Quantized models already have this done.)
    # Remove dropout for this inference-only use case.
    if not quantize_core:
        model.fuse_model()
    if hasattr(model, 'classifier'):
        # assert type(model.classifier[0]) == torch.nn.Dropout
        model.classifier[0] = torch.nn.Identity()

    input_float = torch.zeros(1, 3, 224, 224)
    input_tensor = input_float

    # Optimize the CPU model to make CPU-vs-NNAPI benchmarks fair.
    cpu_model = torch.utils.mobile_optimizer.optimize_for_mobile(torch.jit.script(model))
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        cpu_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])

    try:
        # If we're doing a quantized model, we need to trace only the quantized core.
        # So capture the quantizer and dequantizer, use them to prepare the input,
        # and replace them with identity modules so we can trace without them.
        if quantize_core:
            quantizer = model.quant
            dequantizer = model.dequant
            model.quant = torch.nn.Identity()
            model.dequant = torch.nn.Identity()
            input_tensor = quantizer(input_float)

        # Many NNAPI backends prefer NHWC tensors, so convert our input to channels_last,
        # and set the "nnapi_nhwc" attribute for the converter.
        input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
        input_tensor.nnapi_nhwc = True

        # Trace the model. NNAPI conversion only works with TorchScript models,
        # and traced models are more likely to convert successfully than scripted.
        with torch.no_grad():
            traced = torch.jit.trace(model, input_tensor)
        nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

        # If we're not using a quantized interface, wrap a quant/dequant around the core.
        if quantize_core and not quantize_iface:
            nnapi_model = torch.nn.Sequential(quantizer, nnapi_model, dequantizer)
            model.quant = quantizer
            model.dequant = dequantizer
            # Switch back to float input for benchmarking.
            input_tensor = input_float.contiguous(memory_format=torch.channels_last)

        nnapi_model = torch.jit.script(BundleWrapper(nnapi_model))
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            nnapi_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])
    except Exception:
        print(traceback.format_exc())
        return (cpu_model, None)

    return (cpu_model, nnapi_model)


model_list = ['resnet50']
for model_name in model_list:
    output_dir_path = Path(os.environ["HOME"])
    for quantize_mode in ["none", "core", "full"]:
        model, nnapi_model = make_nnapi(model_name, quantize_mode)

        # Save both models.
        model._save_for_lite_interpreter(os.path.join(output_dir_path, "{}_quant-{}-cpu.pt".format(model_name, quantize_mode)))
        if nnapi_model is not None:
            nnapi_model._save_for_lite_interpreter(os.path.join(output_dir_path, "{}_quant-{}-nnapi.pt".format(model_name, quantize_mode)))
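
# Optional smoke test: a minimal sketch that runs the last CPU model on its
# bundled input directly on the host, via the get_all_bundled_inputs() accessor
# added by augment_model_with_bundled_inputs above. This assumes the
# mobile-optimized module is runnable in the host Python build; real
# benchmarking would still happen on-device through the lite interpreter.
with torch.no_grad():
    (sample_input,) = model.get_all_bundled_inputs()[0]
    print("Host smoke-test output shape:", model(sample_input).shape)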