Skip to content

Instantly share code, notes, and snippets.

@masahi
Created July 10, 2019 13:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masahi/460223846b142b7fc01897143eb732df to your computer and use it in GitHub Desktop.
Save masahi/460223846b142b7fc01897143eb732df to your computer and use it in GitHub Desktop.
import numpy as np
import operator
import tvm
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm import relay
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo
def verify_mxnet_frontend_impl(mx_symbol,
                               data_shape=(1, 3, 224, 224),
                               out_shape=(1, 1000),
                               gluon_impl=False,
                               name=None,
                               dtype='float32'):
    """Compare the output of an MXNet network with its TVM-compiled version.

    The name deliberately does not start with ``test`` so that nose does not
    pick this helper up as a test case.

    A random input of shape ``data_shape`` is fed to the reference MXNet
    network and to the Relay-compiled network, and the two outputs are
    compared with ``rtol=1e-5`` / ``atol=1e-5``.

    Parameters
    ----------
    mx_symbol
        MXNet symbol to convert. Only used when ``gluon_impl`` is False.
    data_shape : tuple of int
        Shape of the random array bound to the ``"data"`` input.
    out_shape : tuple of int
        Shape of the buffer used to fetch output 0 from the TVM runtime.
    gluon_impl : bool
        If True, the model called ``name`` is fetched with
        ``vision.get_model`` and checked on every ``(target, ctx)`` pair
        returned by ``ctx_list()``. If False, ``mx_symbol`` is checked on
        the ``"rocm"`` target only.
    name : str or None
        Gluon model zoo name. Only used when ``gluon_impl`` is True.
    dtype : str
        Data type the random input is cast to before it is fed to both
        frameworks.

    Raises
    ------
    AssertionError
        If ``"data"`` shows up among the trained arguments, or if the two
        outputs differ by more than the tolerance.
    """

    def get_gluon_output(name, x):
        # Build the Gluon model with Xavier-initialised weights and wrap it
        # in a SymbolBlock so the same object can be handed to the frontend.
        net = vision.get_model(name)
        net.collect_params().initialize(mx.init.Xavier())
        net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
                                       inputs=mx.sym.var('data'),
                                       params=net.collect_params())
        out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
        return out, net_sym

    def get_mxnet_output(symbol, x, dtype='float32'):
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
        mod = mx.mod.Module(symbol, label_names=None)
        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
        mod.init_params()
        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
        out = mod.get_outputs()[0].asnumpy()
        # The parameters are returned as well so that TVM runs with exactly
        # the same weights as the reference.
        args, auxs = mod.get_params()
        return out, args, auxs

    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        shape_dict = {"data": x.shape}
        if gluon_impl:
            # No explicit parameters are passed for the Gluon block.
            mod, params = relay.frontend.from_mxnet(symbol, shape_dict)
        else:
            mod, params = relay.frontend.from_mxnet(symbol,
                                                    shape_dict,
                                                    arg_params=args,
                                                    aux_params=auxs)
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(mod, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # random input
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        gluon_out, gluon_sym = get_gluon_output(name, x)
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
            tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
    else:
        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
        assert "data" not in args
        for target in ["rocm"]:  # "rocm -libs=miopen"
            print("Target: ", target)
            ctx = tvm.context(target, 0)
            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
            # Report the deviation BEFORE asserting: when the comparison
            # fails this number is the most useful diagnostic, and printing
            # it afterwards meant it was never shown in that case. Skipped on
            # a shape mismatch so that assert_allclose reports that problem.
            if mx_out.shape == tvm_out.shape:
                print("max abs diff: ", np.max(np.abs(mx_out - tvm_out)))
            tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_vgg():
    """Verify the MXNet-to-TVM conversion for every VGG variant listed here."""
    vgg_variants = (11,)
    for variant in vgg_variants:
        # Build the MXNet symbol and hand it straight to the verifier, which
        # is called with its default shapes and dtype.
        verify_mxnet_frontend_impl(model_zoo.mx_vgg(variant))
def test_forward_resnet():
    """Verify the MXNet-to-TVM conversion for every ResNet variant listed here.

    Each entry of the list is passed to ``model_zoo.mx_resnet`` (presumably
    the number of layers -- confirm against ``model_zoo``) and the resulting
    symbol is checked with ``verify_mxnet_frontend_impl`` using its defaults.
    """
    for n in [18]:
        # Use the loop variable rather than a hard-coded 18, so that adding
        # entries to the list actually tests those variants.
        mx_sym = model_zoo.mx_resnet(n)
        verify_mxnet_frontend_impl(mx_sym)
# Script entry: the tests are invoked directly at import/run time.
# The VGG test is the one that currently passes.
test_forward_vgg() # works
# The ResNet test is kept disabled; running it fails with the error below.
#test_forward_resnet() # HSA_STATUS_ERROR_INVALID_ISA (0x100f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment