Skip to content

Instantly share code, notes, and snippets.

@Wheest
Created June 7, 2022 16:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save Wheest/42df546cedf084eaf8a4206c19a273b4 to your computer and use it in GitHub Desktop.
TVM quantize standalone
#!/usr/bin/env python
import tvm
from tvm import te
from tvm import relay
import mxnet as mx
from tvm.contrib.download import download_testdata
from mxnet import gluon
import logging
import os
import timeit
import numpy as np
# Module-level configuration: batch size, model choice, compile target,
# and the calibration dataset download (cached by TVM on first run).
print("Downloading test data...")
batch_size = 1
model_name = "resnet18_v1"
# x86-64 LLVM target with AVX2 enabled (mcpu=core-avx2).
target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"
dev = tvm.device(target)
# ImageNet validation subset in MXNet RecordIO format, used for calibration.
calibration_rec = download_testdata(
"http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec",
"val_256_q90.rec",
)
print("Downloaded test data...")
def get_val_data(num_workers=4):
    """Build the MXNet validation-set iterator plus a batch-unpacking helper.

    Parameters
    ----------
    num_workers : int
        Number of preprocessing threads for the record iterator.

    Returns
    -------
    tuple
        ``(val_data, batch_fn)`` where ``val_data`` is an
        ``mx.io.ImageRecordIter`` over ``calibration_rec`` and
        ``batch_fn(batch)`` unpacks one batch into ``(data, label)``
        numpy arrays.
    """
    # Standard ImageNet per-channel normalisation statistics (RGB order).
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]

    def batch_fn(batch):
        # Take the first data/label array of the batch and convert to numpy.
        return batch.data[0].asnumpy(), batch.label[0].asnumpy()

    # Inception-v3 expects 299x299 inputs; every other model here uses 224x224.
    img_size = 299 if model_name == "inceptionv3" else 224
    val_data = mx.io.ImageRecordIter(
        path_imgrec=calibration_rec,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn
# Number of images fed through the network when calibrating quantization scales.
calibration_samples = 10


def calibrate_dataset():
    """Yield calibration batches as ``{"data": ndarray}`` dicts.

    Stops after roughly ``calibration_samples`` images have been produced
    (``i * batch_size`` reaching the limit). Consumed by
    ``relay.quantize.quantize`` in ``data_aware`` mode.
    """
    val_data, batch_fn = get_val_data()
    val_data.reset()
    for i, batch in enumerate(val_data):
        if i * batch_size >= calibration_samples:
            break
        data, _ = batch_fn(batch)
        yield {"data": data}
def get_model():
    """Fetch the pretrained Gluon model and convert it to a Relay module.

    Returns
    -------
    tuple
        ``(mod, params)`` — the Relay IRModule and its parameter dict, as
        produced by ``relay.frontend.from_mxnet``.
    """
    gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
    # Inception-v3 expects 299x299 inputs; other models use 224x224.
    img_size = 299 if model_name == "inceptionv3" else 224
    data_shape = (batch_size, 3, img_size, img_size)
    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    return mod, params
def quantize(mod, params, mode="data_aware"):
    """Quantize a Relay module using one of three calibration strategies.

    Parameters
    ----------
    mod, params
        Relay module and parameter dict, e.g. from ``get_model``.
    mode : str
        ``"data_aware"`` — KL-divergence calibration over
        ``calibrate_dataset()`` with max weight scaling;
        ``"global_scale"`` — fixed global scale of 8.0;
        ``"power2"`` — global scale of 8.0 with power-of-two weight scales.

    Returns
    -------
    The quantized Relay module.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the recognised strategies.
    """
    if mode == "data_aware":
        with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
            mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())
    elif mode == "global_scale":
        with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
            mod = relay.quantize.quantize(mod, params)
    elif mode == "power2":
        with relay.quantize.qconfig(
            calibrate_mode="global_scale", global_scale=8.0, weight_scale="power2"
        ):
            mod = relay.quantize.quantize(mod, params)
    else:
        raise ValueError(f"Unknown mode {mode}")
    return mod
def run_inference(mod):
    """Push a handful of validation samples through ``mod``.

    Builds a graph executor for ``mod`` and runs inference on the first
    dozen batches; predictions are discarded — this only exercises the
    compiled model.
    """
    modelr = relay.create_executor("graph", mod, dev, target)
    model = modelr.evaluate()
    val_data, batch_fn = get_val_data()
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch)
        prediction = model(data)
        if i > 10:  # only run inference on a few samples in this tutorial
            break
def benchmark(model):
    """Time ``model`` on a single validation batch.

    Parameters
    ----------
    model : callable
        Compiled executor taking a numpy batch and returning predictions.

    Returns
    -------
    dict
        ``mean``, ``median``, and ``std`` of per-inference latency in
        milliseconds.
    """
    timing_number = 10  # inferences per timed sample
    timing_repeat = 10  # number of timed samples
    val_data, batch_fn = get_val_data()
    # Grab exactly one batch to benchmark with.
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch)
        break
    # timeit returns seconds per `timing_number` calls; convert to
    # milliseconds per single inference.
    samples_ms = (
        np.array(
            timeit.Timer(lambda: model(data)).repeat(
                repeat=timing_repeat, number=timing_number
            )
        )
        * 1000
        / timing_number
    )
    return {
        "mean": np.mean(samples_ms),
        "median": np.median(samples_ms),
        "std": np.std(samples_ms),
    }
def test(mode="global_scale"):
    """Quantize the model with ``mode``, benchmark it, and return the timings.

    Prints and returns the timing dict produced by ``benchmark``.
    """
    mod, params = get_model()
    mod = quantize(mod, params, mode=mode)
    model = relay.create_executor("graph", mod, dev, target).evaluate()
    times = benchmark(model)
    print(f"For {mode}:", times)
    return times
def test_normal():
    """Benchmark the unquantized (fp32) baseline compiled at opt level 3."""
    mod, params = get_model()
    with tvm.transform.PassContext(opt_level=3):
        model = relay.build_module.create_executor(
            "graph", mod, dev, target, params
        ).evaluate()
    return benchmark(model)
def main():
    """Benchmark each quantization mode plus the fp32 baseline; print results."""
    times = dict()
    # Other modes are disabled for speed; re-enable to compare calibrations.
    modes = ["global_scale"]  # , "data_aware", "power2"]
    for m in modes:
        times[m] = test(m)
    times["normal"] = test_normal()
    print(times)
    for k, v in times.items():
        print(k, v)
        print()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment