@Wheest
Last active March 15, 2023 10:36
Debug script for TVM int8 quantization
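The script traces a torchvision DenseNet-161 with PyTorch, imports it into Relay, runs a standard cleanup pass pipeline, applies TVM's post-training int8 quantization with a fixed global scale, and times a single inference. A working PyTorch and TVM installation (with the Relay quantization module) is assumed.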
#!/usr/bin/env python3
import os
import time
import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.transform import InferType, ToMixedPrecision
np.random.seed(42)
TEST_DATASETS = ["cifar10", "imagenet", "test"]
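
# Post-training int8 quantization. calibrate_mode="global_scale" applies the
# single fixed scale below to every layer, so no calibration dataset is needed.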
def quantize(mod, params):
    with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
        mod = relay.quantize.quantize(mod, params)
    return mod
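
# Inference through the Relay VM executor (kept for debugging; not called from main()).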
def run_inference_vm(mod, dev, target, in_shape):
    model = relay.create_executor("vm", mod, dev, target)
    model = model.evaluate()
    data = np.random.uniform(5, 10, in_shape).astype(np.float32)
    prediction = model(data)
    return prediction
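
# TIR pass that prints every PrimFunc it sees and returns it unchanged;
# hooked into lowering via the tir.add_lower_pass config in run_inference_fp32.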
@tvm.tir.transform.prim_func_pass(opt_level=0)
def print_tir(f, mod, ctx):
    print(f)
    return f
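
# Inference through the graph executor; this is the path exercised by main().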
def run_inference(mod, dev, target, in_shape, input_name="input0"):
    executor = tvm.relay.create_executor("graph", mod, dev, target)
    executor._make_executor()
    # assumes the executor exposes its underlying GraphModule as .graph_module
    model = executor.graph_module
    data = np.random.uniform(5, 10, in_shape).astype(np.float32)
    model.set_input(input_name, data)
    model.run()
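
# fp32 baseline: build with relay.build and run through a GraphModule,
# with print_tir attached at lowering phase 3 to dump the generated TIR.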
def run_inference_fp32(mod, params, input_name, dev, target, in_shape):
    with tvm.transform.PassContext(
        opt_level=3, config={"tir.add_lower_pass": [(3, print_tir)]}
    ):
        lib = tvm.relay.build(mod, target=target, params=params)
    model = graph_executor.GraphModule(lib["default"](dev))
    data = np.random.uniform(5, 10, in_shape).astype(np.float32)
    model.set_input(input_name, data)
    model.run()
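
# Standard Relay cleanup: fold BatchNorm into surrounding ops, bind the weights,
# fold constants, and optionally rewrite the graph to mixed fp16 precision.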
def model_opt(mod, params, run_fp16_pass=False, run_other_opts=True, fast_math=False):
    # code adapted from https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/f1f9f698be2b7a8cc5bcf1167d892cd915eb7ce7/fp16_pass/benchmark_fp16.py#L19
    mod = tvm.IRModule.from_expr(mod["main"])
    remove_bn_pass = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),
        ]
    )
    mod = remove_bn_pass(mod)
    if run_other_opts:
        mod = tvm.relay.transform.FastMath()(mod) if fast_math else mod
        mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)
        BindPass = tvm.relay.transform.function_pass(
            lambda fn, new_mod, ctx: tvm.relay.build_module.bind_params_by_name(
                fn, params
            ),
            opt_level=1,
        )
        mod = BindPass(mod)
        mod = tvm.relay.transform.FoldConstant()(mod)
        mod = tvm.relay.transform.CombineParallelBatchMatmul()(mod)
        mod = tvm.relay.transform.FoldConstant()(mod)
    if run_fp16_pass:
        mod = InferType()(mod)
        mod = ToMixedPrecision()(mod)
    if run_other_opts and run_fp16_pass:
        # run one more pass to clean up new subgraph
        mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)
        mod = tvm.relay.transform.FoldConstant()(mod)
        mod = tvm.relay.transform.CombineParallelBatchMatmul()(mod)
        mod = tvm.relay.transform.FoldConstant()(mod)
        mod = tvm.relay.transform.FastMath()(mod) if fast_math else mod
    return mod, params
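
# Entry point: pick a target, import DenseNet-161 from torchvision via the
# PyTorch frontend, run the cleanup passes, quantize to int8, and time one inference.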
def main():
    device = "x86_cpu"
    if device == "x86_cpu":
        target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"
        dev = tvm.device(target)
    elif device == "arm_cpu":
        dev = tvm.cpu(0)
        target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
    elif device == "arm_cuda":
        target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
        target = tvm.target.Target("cuda", host=target)
        dev = tvm.cuda(0)
    else:
        raise ValueError(f"Unknown device: {device}")

    model = torch.hub.load(
        "pytorch/vision:v0.11.0", "densenet161", pretrained=False
    ).eval()
    # model = torch.hub.load(
    #     "pytorch/vision:v0.11.0", "resnet50", pretrained=False
    # ).eval()
    # model = model_dict["densenet161-imagenet"]()
    in_shape = [1, 3, 224, 224]
    input_name = "input0"
    input_data = torch.randn(in_shape)
    scripted_model = torch.jit.trace(model, input_data).eval()
    shape_list = [(input_name, in_shape)]
    mod, params = tvm.relay.frontend.from_pytorch(scripted_model, shape_list)
    print(mod)

    # exit(1)
    # start = time.time()
    # run_inference_fp32(mod, params, input_name, dev, target, in_shape)
    # # run_tests(mod2, dev, target, test_data)
    # end = time.time()
    # print("fp32:", end - start)

    print("loaded model")
    mod, params = model_opt(mod, params)
    print(mod)
    mod2 = quantize(mod, params)
    print("quantized")
    start = time.time()
    print(mod2)
    run_inference(mod2, dev, target, in_shape, input_name)
    end = time.time()
    print(end - start)


if __name__ == "__main__":
    main()