Skip to content

Instantly share code, notes, and snippets.

@vinx13
Last active August 16, 2019 06:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa to your computer and use it in GitHub Desktop.
Save vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa to your computer and use it in GitHub Desktop.
import numpy as np
import logging
import argparse
import os
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import tvm
import tvm.relay as relay
import tvm.relay.expr as _expr
import tvm.relay.transform as _transform
from tvm.contrib import graph_runtime
from scipy import stats
import pickle
import multiprocessing as mp
# Helper for reading the ImageNet validation set from a RecordIO (.rec) file
def get_val_data(args,
rec_val,
batch_size,
num_workers=4,
shuffle=False):
rec_val = os.path.expanduser(rec_val)
mean_rgb = [123.68, 116.779, 103.939]
std_rgb = [58.393, 57.12, 57.375]
def batch_fn(batch, ctx):
data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
return data, label
img_size = 299 if args.model == 'inceptionv3' else 224
val_data = mx.io.ImageRecordIter(
path_imgrec = rec_val,
preprocess_threads = num_workers,
shuffle = shuffle,
batch_size = batch_size,
resize = 256,
data_shape = (3, img_size, img_size),
mean_r = mean_rgb[0],
mean_g = mean_rgb[1],
mean_b = mean_rgb[2],
std_r = std_rgb[0],
std_g = std_rgb[1],
std_b = std_rgb[2],
)
return val_data, batch_fn
def calibration_dataset():
    """Yield calibration feed dicts drawn from the shuffled validation set.

    Reads the global ``args`` (``rec_val``, ``batch_size``,
    ``calibration_samples``, and ``model`` via ``get_val_data``). Yields
    ``ceil(calibration_samples / batch_size)`` batches (fewer if the dataset
    is exhausted first), each as ``{'data': <numpy array>}``.
    """
    val_data, batch_fn = get_val_data(args, args.rec_val, args.batch_size, shuffle=True)
    val_data.reset()
    for i, batch in enumerate(val_data):
        # ``i * batch_size`` samples have already been yielded; stop once that
        # reaches the requested count. Using ``>`` here yielded one extra batch
        # whenever calibration_samples was a multiple of batch_size.
        if i * args.batch_size >= args.calibration_samples:
            break
        data, _ = batch_fn(batch, [mx.cpu(0)])
        yield {'data': data[0].asnumpy()}
def evaluate(args, graph, lib, params, ctx):
    """Evaluate a compiled model on the validation set.

    Top-1 and top-5 accuracy are logged every ``args.log_interval`` batches
    (when non-zero) and once at the end. The final result is appended to
    ``args.record_file`` as a ``"<model>, <top1> / <top5>"`` line.

    Args:
        args: parsed command-line arguments (reads ``batch_size``,
            ``rec_val``, ``model``, ``num_classes``, ``log_interval`` and
            ``record_file``).
        graph, lib, params: the outputs of ``relay.build``.
        ctx: TVM context the graph runtime executes on.
    """
    # Set up the dataset.
    batch_size = args.batch_size
    val_data, batch_fn = get_val_data(args, args.rec_val, batch_size)

    # Create the runtime module and an output buffer reused across batches.
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    oshape = (batch_size, args.num_classes)
    out_arr = tvm.nd.empty(oshape, "float32")

    # Set up the evaluation metrics.
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    val_data.reset()
    acc_top1.reset()
    acc_top5.reset()

    nsamples = 0
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        m.get_output(0, out_arr)
        # A final partial batch is padded up to ``batch_size``; ``batch.pad``
        # counts the padded rows, which must not be scored.
        num_valid = batch_size - (batch.pad or 0)
        preds = [mx.nd.array(out_arr.asnumpy()[:num_valid])]
        labels = [label[0][:num_valid]]
        acc_top1.update(labels, preds)
        acc_top5.update(labels, preds)
        nsamples += num_valid
        if args.log_interval and not (i + 1) % args.log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)

    # Read the metrics after the loop. Relying on the values from the last
    # periodic log raised NameError when no log happened and otherwise
    # ignored every batch after the last log.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
    with open(args.record_file, "a") as f:
        f.write('{}, {} / {}\n'.format(
            args.model, top1, top5))
def calibrate_on_dataset(qgraph):
    """Derive per-output quantization scales from calibration data.

    ``qgraph`` is instrumented with ``relay.quantize.collect_stats``, compiled
    (opt_level 3) for the global ``args.target`` / ``args.device_id`` and run
    on every feed dict produced by ``calibration_dataset()``. For each runtime
    output, the values from all batches are flattened into one array, and
    ``kl_divergence_scale`` is applied to each array in a process pool.

    Returns:
        A list with one scale per output of the profiling graph, in output
        order.
    """
    stats_func = relay.quantize.collect_stats(qgraph)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(relay.Module.from_expr(stats_func), target=args.target)

    runtime = graph_runtime.create(graph, lib, tvm.context(args.target, args.device_id))
    runtime.set_input(**params)
    n_out = runtime.get_num_outputs()
    collected = [[] for _ in range(n_out)]

    for batch_id, feed in enumerate(calibration_dataset()):
        print('batch {}..'.format(batch_id))
        runtime.set_input(**feed)
        runtime.run()
        for idx, bucket in enumerate(collected):
            bucket.append(runtime.get_output(idx).asnumpy())

    flattened = [np.concatenate(bucket).reshape(-1) for bucket in collected]
    with mp.Pool() as pool:
        return list(pool.map(relay.quantize.kl_divergence.kl_divergence_scale, flattened))
def build_model(gluon_model, original):
"""Build with relay."""
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
img_size = 299 if args.model == 'inceptionv3' else 224
data_shape = (args.batch_size, 3, img_size, img_size)
mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
target = args.target
ctx = tvm.context(target, args.device_id)
if original:
# run original model
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
return graph, lib, params, ctx
skip_conv_layers = [0]
with relay.quantize.qconfig(store_lowbit_output=False, skip_conv_layers=skip_conv_layers):
from tvm.relay.quantize.quantize import _bind_params
graph = _bind_params(mod['main'], params)
mod = relay.Module.from_expr(graph)
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
with relay.build_config(opt_level=2):
mod = optimize(mod)
mod = relay.quantize.annotate()(mod)
cache_file = '{}_scales.pkl'.format(args.model)
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
scales = pickle.load(f)
else:
scales = calibrate_on_dataset(mod['main'])
with open(cache_file, 'wb') as f:
pickle.dump(scales, f)
if args.eval_power2:
scales = list(map(lambda scale: 2**np.math.ceil(np.math.log(scale, 2)) if scale > 0 else 1.0, scales))
weight_scales = 'power2'
else:
weight_scales = 'max'
mod['main'] = relay.quantize.calibrate(mod['main'], weight_scales=weight_scales,
scales=scales)
mod = relay.quantize.realize()(mod)
mod = relay.transform.FoldConstant()(mod)
graph, lib, params = relay.build(mod, target=args.target)
return graph, lib, params, ctx
def save_model(name, graph, lib, params):
    """Write a compiled model to disk using ``name`` as the path prefix.

    Produces, in this order: ``<name>.json`` (the graph text),
    ``<name>.so`` (via ``lib.export_library``) and ``<name>.bin`` (the
    parameters serialized with ``relay.save_param_dict``).
    """
    graph_path, lib_path, params_path = (name + suffix for suffix in ('.json', '.so', '.bin'))

    with open(graph_path, 'w') as graph_file:
        graph_file.write(graph)

    lib.export_library(lib_path)

    with open(params_path, 'wb') as params_file:
        params_file.write(relay.save_param_dict(params))
def main():
    """Build the requested model, optionally save it, then evaluate it.

    Reads the global ``args``. When ``args.save_model`` is set, the compiled
    artifacts are written with ``save_model`` using it as the path prefix.
    """
    gluon_model = vision.get_model(args.model, pretrained=True)
    graph, lib, params, ctx = build_model(gluon_model, args.original)
    logging.info("Finish building model %s...", args.model)
    # Honor --save_model; it was parsed but never acted on.
    if args.save_model:
        save_model(args.save_model, graph, lib, params)
        logging.info("Saved model to %s.{json,so,bin}", args.save_model)
    evaluate(args, graph, lib, params, ctx)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy")
parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec",
help="the validation data")
parser.add_argument("--num-classes", type=int, default=1000,
help="batch size")
parser.add_argument("--model", type=str, default="resnet50_v2",
help="Name of the model")
parser.add_argument("--log-interval", type=int, default=100,
help="log interval")
parser.add_argument("--batch-size", type=int, default=1,
help="batch size")
parser.add_argument("--target", type=str, default="cuda",
help="target option")
parser.add_argument("--original", action="store_true",
help='whether to use original graph')
parser.add_argument('--save_model', type=str, default=None)
parser.add_argument('--calibration_samples', type=int, default=100)
parser.add_argument('--device-id', type=int, default=0)
parser.add_argument('--eval-power2', action='store_true',
help='in this mode, scales are restricted to power-of-2 (weight: power2' \
'scale, activation: round kld to power2)')
parser.add_argument('--record-file', type=str, default='record.csv',
help='file to save eval result')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logging.info(args)
main()
@mingwayzhang
Copy link

Hi,
I am wondering which version of TVM this script works with. I tried to run it, but it errors out with "relay.quantize does not have collect_stats" at line 98. I looked, and there is no such function at all in recent TVM.

So, if this one is obsolete, which script is the right one for evaluating the effect of quantization? Thanks

@vinx13
Copy link
Author

vinx13 commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 14, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 14, 2019

@mingwayzhang it's ILSVRC2012 val

@mingwayzhang
Copy link

mingwayzhang commented Aug 14, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 14, 2019

@mingwayzhang what model are you using?

@mingwayzhang
Copy link

mingwayzhang commented Aug 15, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 15, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 16, 2019

@mingwayzhang I have updated line 149 and it works fine locally. There are some accuracy issues after #3543 is merged. I'm working on it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment