Created
July 10, 2019 13:43
-
-
Save masahi/460223846b142b7fc01897143eb732df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import operator | |
import tvm | |
from tvm.contrib import graph_runtime | |
from tvm.relay.testing.config import ctx_list | |
from tvm import relay | |
import mxnet as mx | |
from mxnet import gluon | |
from mxnet.gluon.model_zoo import vision | |
import model_zoo | |
def verify_mxnet_frontend_impl(mx_symbol,
                               data_shape=(1, 3, 224, 224),
                               out_shape=(1, 1000),
                               gluon_impl=False,
                               name=None,
                               dtype='float32',
                               targets=("rocm",)):
    """Run a model through MXNet and through TVM and check the outputs agree.

    The name deliberately does not start with ``test`` so that nose/pytest
    do not collect it as a standalone test.

    Parameters
    ----------
    mx_symbol : mxnet symbol
        Symbol to compile when ``gluon_impl`` is False (ignored otherwise).
    data_shape : tuple of int
        Shape of the random input bound to the ``"data"`` variable.
    out_shape : tuple of int
        Shape of the buffer that TVM's output 0 is copied into.
    gluon_impl : bool
        If True, build the Gluon model-zoo network ``name`` instead of using
        ``mx_symbol`` and compare on every (target, ctx) from ``ctx_list()``.
    name : str or None
        Gluon model-zoo model name; required when ``gluon_impl`` is True.
    dtype : str
        Dtype the input is cast to, and dtype of the TVM output buffer.
    targets : iterable of str
        TVM targets compiled for when ``gluon_impl`` is False, e.g.
        ``"rocm"`` or ``"rocm -libs=miopen"``. Defaults to ``("rocm",)``.

    Raises
    ------
    ValueError
        If ``gluon_impl`` is True but no model ``name`` was given.
    AssertionError
        If the MXNet params unexpectedly contain ``"data"``, or if the TVM
        and reference outputs differ beyond rtol=atol=1e-5.
    """
    if gluon_impl and name is None:
        raise ValueError("gluon_impl=True requires a model-zoo `name`")

    def get_gluon_output(name, x):
        """Build Gluon model `name`, Xavier-initialize it, and run it on `x`."""
        net = vision.get_model(name)
        net.collect_params().initialize(mx.init.Xavier())
        net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
                                       inputs=mx.sym.var('data'),
                                       params=net.collect_params())
        out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
        return out, net_sym

    def get_mxnet_output(symbol, x, dtype='float32'):
        """Bind `symbol` for inference, init its params, and run it on `x`.

        Returns the first output as a numpy array plus the (arg, aux)
        parameter dicts so TVM can be fed the same weights.
        """
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
        mod = mx.mod.Module(symbol, label_names=None)
        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
        mod.init_params()
        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
        out = mod.get_outputs()[0].asnumpy()
        args, auxs = mod.get_params()
        return out, args, auxs

    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        """Compile `symbol` with Relay for `target`, run it on `ctx`, return output 0."""
        shape_dict = {"data": x.shape}
        if gluon_impl:
            # Gluon blocks are converted without explicit arg/aux params
            # (presumably from_mxnet reads them from the block itself).
            mod, params = relay.frontend.from_mxnet(symbol, shape_dict)
        else:
            mod, params = relay.frontend.from_mxnet(symbol,
                                                    shape_dict,
                                                    arg_params=args,
                                                    aux_params=auxs)
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(mod, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # Random input in [0, 1); generated as float64 and cast to `dtype` per run.
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        gluon_out, gluon_sym = get_gluon_output(name, x)
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
            tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
    else:
        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
        # The input must be fed at run time, not baked in as a parameter.
        assert "data" not in args
        for target in targets:
            print("Target: ", target)
            ctx = tvm.context(target, 0)
            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
            # Report the discrepancy *before* asserting so it is visible when
            # the comparison fails; skip it on shape mismatch and let the
            # assertion report that instead.
            if mx_out.shape == tvm_out.shape:
                print("max abs diff: ", np.max(np.abs(mx_out - tvm_out)))
            tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_vgg():
    """Compare the local model_zoo VGG symbol(s) between MXNet and TVM."""
    depths = [11]
    for depth in depths:
        verify_mxnet_frontend_impl(model_zoo.mx_vgg(depth))
def test_forward_resnet():
    """Compare the local model_zoo ResNet symbol(s) between MXNet and TVM."""
    for n in [18]:
        # Build the network for the current depth. The original passed a
        # literal 18 here, so adding depths to the list would have silently
        # re-tested ResNet-18.
        mx_sym = model_zoo.mx_resnet(n)
        verify_mxnet_frontend_impl(mx_sym)
if __name__ == "__main__":
    # Run only when executed as a script, so importing this module (e.g. from
    # nose/pytest collection) does not trigger a full compile-and-run.
    test_forward_vgg()  # works
    # Disabled: fails on ROCm with HSA_STATUS_ERROR_INVALID_ISA (0x100f).
    # test_forward_resnet()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment