Skip to content

Instantly share code, notes, and snippets.

@vinx13
Last active August 16, 2019 06:17
Show Gist options
  • Save vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa to your computer and use it in GitHub Desktop.
Save vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa to your computer and use it in GitHub Desktop.
import numpy as np
import logging
import argparse
import os
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import tvm
import tvm.relay as relay
import tvm.relay.expr as _expr
import tvm.relay.transform as _transform
from tvm.contrib import graph_runtime
from scipy import stats
import pickle
import multiprocessing as mp
# Data loading: validation batches are read from an MXNet record file (calibration reuses the same reader)
def get_val_data(args,
                 rec_val,
                 batch_size,
                 num_workers=4,
                 shuffle=False):
    """Build a record-file image iterator plus a helper that splits its batches.

    Parameters
    ----------
    args : object
        Only ``args.model`` is read here; ``'inceptionv3'`` selects a
        299x299 input, every other model name selects 224x224.
    rec_val : str
        Path to the record file; ``~`` is expanded.
    batch_size : int
        Batch size handed to the iterator.
    num_workers : int
        Forwarded as ``preprocess_threads``.
    shuffle : bool
        Forwarded as ``shuffle``.

    Returns
    -------
    (val_data, batch_fn)
        ``val_data`` is the ``mx.io.ImageRecordIter``; ``batch_fn(batch, ctx)``
        returns ``(data, label)`` as produced by
        ``gluon.utils.split_and_load`` over the context list ``ctx``.
    """
    def batch_fn(batch, ctx):
        # Only the first data / label array of the batch is used.
        return tuple(
            gluon.utils.split_and_load(arrays[0], ctx_list=ctx, batch_axis=0)
            for arrays in (batch.data, batch.label)
        )

    # Per-channel normalisation constants passed straight to the iterator.
    normalization = {
        'mean_r': 123.68,
        'mean_g': 116.779,
        'mean_b': 103.939,
        'std_r': 58.393,
        'std_g': 57.12,
        'std_b': 57.375,
    }

    if args.model == 'inceptionv3':
        img_size = 299
    else:
        img_size = 224

    val_data = mx.io.ImageRecordIter(
        path_imgrec=os.path.expanduser(rec_val),
        preprocess_threads=num_workers,
        shuffle=shuffle,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        **normalization
    )
    return val_data, batch_fn
def calibration_dataset():
    """Yield shuffled validation batches to feed the calibration graph.

    Reads the module-level ``args`` (``rec_val``, ``batch_size``,
    ``calibration_samples``, and ``model`` via ``get_val_data``).

    Yields
    ------
    dict
        ``{'data': ndarray}`` where the array is the first split of the batch
        loaded on ``mx.cpu(0)`` and converted with ``asnumpy()``.

    Iteration stops as soon as at least ``args.calibration_samples`` samples
    have been yielded (or the underlying iterator is exhausted).
    """
    val_data, batch_fn = get_val_data(args, args.rec_val, args.batch_size, shuffle=True)
    val_data.reset()
    for i, batch in enumerate(val_data):
        # Before batch i is emitted, i * batch_size samples are already out.
        # ``>=`` (rather than ``>``) avoids emitting one extra batch when
        # calibration_samples is an exact multiple of batch_size.
        if i * args.batch_size >= args.calibration_samples:
            break
        data, _ = batch_fn(batch, [mx.cpu(0)])
        yield {'data': data[0].asnumpy()}
def evaluate(args, graph, lib, params, ctx):
    """Evaluate a compiled model on the validation set.

    Parameters
    ----------
    args : object
        Reads ``batch_size``, ``rec_val``, ``num_classes``, ``log_interval``,
        ``record_file`` and ``model`` (the latter also via ``get_val_data``).
    graph, lib, params
        Build artefacts; ``graph`` and ``lib`` go to ``graph_runtime.create``
        and ``params`` is bound with ``set_input``.
    ctx
        Device context for the runtime module.

    Side effects
    ------------
    Logs running and final top-1 / top-5 accuracy, and appends one line
    ``"<model>, <top1> / <top5>"`` to ``args.record_file``.
    """
    # setup dataset.
    batch_size = args.batch_size
    val_data, batch_fn = get_val_data(args, args.rec_val, batch_size)
    # create runtime module
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    oshape = (batch_size, args.num_classes)
    out_arr = tvm.nd.empty(oshape, "float32")
    # setup evaluation metric
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    val_data.reset()
    acc_top1.reset()
    acc_top5.reset()
    # Execute
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        m.get_output(0, out_arr)
        # Convert the output once and share it between both metrics.
        preds = [mx.nd.array(out_arr.asnumpy())]
        acc_top1.update(label, preds)
        acc_top5.update(label, preds)

        if args.log_interval and not (i + 1) % args.log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)

    # Read the metrics after the loop: the values fetched inside the logging
    # branch are stale (or never bound when logging is disabled or the dataset
    # has fewer batches than log_interval).
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
    with open(args.record_file, "a") as f:
        f.write('{}, {} / {}\n'.format(
            args.model, top1, top5))
def calibrate_on_dataset(qgraph):
    """Run the calibration set through a profiling graph and derive scales.

    ``qgraph`` is passed to ``relay.quantize.collect_stats``; the resulting
    expression is built for ``args.target`` and executed on
    ``tvm.context(args.target, args.device_id)`` for every batch produced by
    ``calibration_dataset()``. Each runtime output is accumulated across
    batches, flattened to 1-D, and handed to
    ``relay.quantize.kl_divergence.kl_divergence_scale`` in a process pool.

    Returns
    -------
    list
        One scale per runtime output, in output order.

    Reads the module-level ``args`` (``target``, ``device_id``).
    """
    profile_graph = relay.quantize.collect_stats(qgraph)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(relay.Module.from_expr(profile_graph), target=args.target)

    m = graph_runtime.create(graph, lib, tvm.context(args.target, args.device_id))
    m.set_input(**params)

    num_outputs = m.get_num_outputs()
    # One list of per-batch arrays for each output of the profiling graph.
    outputs = [[] for _ in range(num_outputs)]

    for batch_id, batch in enumerate(calibration_dataset()):
        print('batch {}..'.format(batch_id))
        m.set_input(**batch)
        m.run()
        for i in range(num_outputs):
            output = m.get_output(i).asnumpy()
            outputs[i].append(output)

    # Replace each list in place so per-batch chunks can be released as soon
    # as their flattened concatenation exists.
    for i in range(num_outputs):
        outputs[i] = np.concatenate(outputs[i]).reshape(-1)

    with mp.Pool() as pool:
        scales = list(pool.map(relay.quantize.kl_divergence.kl_divergence_scale, outputs))
    return scales
def build_model(gluon_model, original):
    """Build with relay.

    Parameters
    ----------
    gluon_model
        Model converted with ``relay.frontend.from_mxnet`` using a single
        input named ``"data"``.
    original : bool
        When true, build the unquantized model at ``opt_level=3`` and return
        immediately. Otherwise annotate, calibrate and realize a quantized
        model before building.

    Returns
    -------
    (graph, lib, params, ctx)
        The ``relay.build`` artefacts and ``tvm.context(args.target,
        args.device_id)``.

    Reads the module-level ``args`` (``model``, ``batch_size``, ``target``,
    ``device_id``, ``eval_power2``). Calibration scales are cached in
    ``<model>_scales.pkl`` in the current directory and reused when present.
    """
    import math

    import tvm
    from tvm import relay

    img_size = 299 if args.model == 'inceptionv3' else 224
    data_shape = (args.batch_size, 3, img_size, img_size)
    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    target = args.target
    ctx = tvm.context(target, args.device_id)

    if original:
        # run original model
        with relay.build_config(opt_level=3):
            # Build the converted module; the previous code referenced an
            # undefined name ``net`` here and raised NameError.
            graph, lib, params = relay.build(mod, target, params=params)
        return graph, lib, params, ctx

    # NOTE(review): the nesting of the two ``with`` scopes below is
    # reconstructed; every quantize pass is kept inside the qconfig scope.
    # Confirm against the upstream script.
    skip_conv_layers = [0]
    with relay.quantize.qconfig(store_lowbit_output=False, skip_conv_layers=skip_conv_layers):
        from tvm.relay.quantize.quantize import _bind_params
        graph = _bind_params(mod['main'], params)
        mod = relay.Module.from_expr(graph)
        optimize = _transform.Sequential([_transform.SimplifyInference(),
                                          _transform.FoldConstant(),
                                          _transform.FoldScaleAxis(),
                                          _transform.CanonicalizeOps(),
                                          _transform.FoldConstant()])

        with relay.build_config(opt_level=2):
            mod = optimize(mod)
            mod = relay.quantize.annotate()(mod)

            cache_file = '{}_scales.pkl'.format(args.model)
            if os.path.exists(cache_file):
                # NOTE: unpickling can execute arbitrary code; only reuse
                # cache files written by this script.
                with open(cache_file, 'rb') as f:
                    scales = pickle.load(f)
            else:
                scales = calibrate_on_dataset(mod['main'])
                with open(cache_file, 'wb') as f:
                    pickle.dump(scales, f)

            if args.eval_power2:
                # Round each positive scale up to the next power of two;
                # non-positive scales fall back to 1.0. Uses stdlib ``math``
                # because the ``np.math`` alias no longer exists in NumPy 2.
                scales = [2 ** math.ceil(math.log(scale, 2)) if scale > 0 else 1.0
                          for scale in scales]
                weight_scales = 'power2'
            else:
                weight_scales = 'max'

            mod['main'] = relay.quantize.calibrate(mod['main'], weight_scales=weight_scales,
                                                   scales=scales)
            mod = relay.quantize.realize()(mod)
            mod = relay.transform.FoldConstant()(mod)

            graph, lib, params = relay.build(mod, target=target)
    return graph, lib, params, ctx
def save_model(name, graph, lib, params):
    """Write build artefacts next to each other using ``name`` as the prefix.

    Produces ``<name>.json`` (``graph`` written as text), ``<name>.so``
    (via ``lib.export_library``) and ``<name>.bin`` (the bytes returned by
    ``relay.save_param_dict(params)``), in that order.
    """
    graph_path = name + '.json'
    with open(graph_path, 'w') as graph_file:
        graph_file.write(graph)

    library_path = name + '.so'
    lib.export_library(library_path)

    params_path = name + '.bin'
    with open(params_path, 'wb') as params_file:
        serialized_params = relay.save_param_dict(params)
        params_file.write(serialized_params)
def main():
    """Fetch the pretrained model, build it, optionally save it, and evaluate.

    Reads the module-level ``args``: ``model`` selects the model-zoo network,
    ``original`` chooses between the unquantized and quantized build, and
    ``save_model`` (when set) is used as the file-name prefix passed to
    ``save_model()``.
    """
    gluon_model = vision.get_model(args.model, pretrained=True)
    graph, lib, params, ctx = build_model(gluon_model, args.original)
    logging.info("Finish building model %s...", args.model)
    # The --save_model option was parsed but never acted upon; honour it here
    # so the artefacts are persisted before the (long) evaluation starts.
    if args.save_model:
        save_model(args.save_model, graph, lib, params)
    evaluate(args, graph, lib, params, ctx)
if __name__ == "__main__":
    # Command-line entry point. ``args`` is deliberately a module-level name:
    # the helper functions in this file read it as a global.
    parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy")
    parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec",
                        help="the validation data")
    # Help text previously said "batch size" (copy/paste slip).
    parser.add_argument("--num-classes", type=int, default=1000,
                        help="number of classes")
    parser.add_argument("--model", type=str, default="resnet50_v2",
                        help="Name of the model")
    parser.add_argument("--log-interval", type=int, default=100,
                        help="log interval")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch size")
    parser.add_argument("--target", type=str, default="cuda",
                        help="target option")
    parser.add_argument("--original", action="store_true",
                        help='whether to use original graph')
    parser.add_argument('--save_model', type=str, default=None)
    parser.add_argument('--calibration_samples', type=int, default=100,
                        help='number of samples drawn for calibration')
    parser.add_argument('--device-id', type=int, default=0,
                        help='device id used to create the TVM context')
    # The two literals used to join as "power2scale"; keep a space between them.
    parser.add_argument('--eval-power2', action='store_true',
                        help='in this mode, scales are restricted to power-of-2 (weight: power2 '
                             'scale, activation: round kld to power2)')
    parser.add_argument('--record-file', type=str, default='record.csv',
                        help='file to save eval result')
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logging.info(args)
    main()
@vinx13
Copy link
Author

vinx13 commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 2, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 14, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 14, 2019

@mingwayzhang it's ILSVRC2012 val

@mingwayzhang
Copy link

mingwayzhang commented Aug 14, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 14, 2019

@mingwayzhang what model are you using?

@mingwayzhang
Copy link

mingwayzhang commented Aug 15, 2019 via email

@mingwayzhang
Copy link

mingwayzhang commented Aug 15, 2019 via email

@vinx13
Copy link
Author

vinx13 commented Aug 16, 2019

@mingwayzhang I have updated line 149 and it works fine locally. There are some accuracy issues after #3543 is merged. I'm working on it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment