Skip to content

Instantly share code, notes, and snippets.

@apivovarov
Last active September 9, 2022 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save apivovarov/cb148ffbb55a37d26d3ef6e2503b0fc8 to your computer and use it in GitHub Desktop.
Save apivovarov/cb148ffbb55a37d26d3ef6e2503b0fc8 to your computer and use it in GitHub Desktop.
Compile an MXNet nn.HybridBlock with TVM Relay
import tvm
from tvm import relay
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
# Log the MXNet release so the printed results can be tied to a version.
print(mx.__version__)

# CPU context used for the MXNet input array and for reloading the model.
ctx = mx.cpu()

# Export file prefix, name of the graph input, and input element type.
model_name = "mxnet_shape"
input_name = "data"
data_type = "float32"
# https://mxnet.apache.org/versions/1.7/api/python/docs/api/gluon/hybrid_block.html
class Net(nn.HybridBlock):
    """Parameter-free hybrid block whose output is the shape of its input.

    Used here to exercise conversion of the ``shape_array`` operator.
    Constructor keyword arguments are forwarded unchanged to
    ``nn.HybridBlock`` (inherited constructor).
    """

    def hybrid_forward(self, F, x):
        """Return ``F.shape_array(x)``, a 1-D array holding ``x``'s shape.

        ``F`` is the NDArray API when run imperatively and the Symbol API
        once the block is hybridized, so one body serves both modes.
        Per the MXNet operator docs the result dtype is int64.
        """
        return F.shape_array(x)
# Build the block, hybridize it so a symbolic graph is traced, and run it once
# on a sample input.
net = Net()
net.initialize(ctx=ctx)  # same context as the sample input below
net.hybridize()
x = mx.nd.array([2, 3, 0, 1, 1, 1, 2], ctx=ctx, dtype=data_type)
y = net(x)
print(y, y.dtype)

# Writes <model_name>-symbol.json and <model_name>-0000.params. MXNet requires
# the block to be hybridized and run forward at least once before export().
net.export(model_name)
print("Loading mxnet model...", model_name)
# Reload the exported graph/params as a SymbolBlock whose single input is
# named input_name.
block = gluon.nn.SymbolBlock.imports(
    model_name + "-symbol.json",
    [input_name],
    model_name + "-0000.params",
    ctx=ctx,
)
# Derive the input signature from the sample array rather than repeating the
# literals "data" and 7, so it stays in sync with input_name and x.
shape_dict = {input_name: x.shape}
print("relay.frontend.from_mxnet...")
mod, params = relay.frontend.from_mxnet(block, shape_dict, dtype=data_type)
print("Parsing Done")
print(mod["main"])
# Compile the Relay module with the VM compiler for an LLVM (host CPU) target.
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, params=params, target=target)
print("Compilation Done")

# Execute on the CPU device (matching the "llvm" target), feeding the same
# sample input given to MXNet so both printed results can be compared.
vm = tvm.runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
ff_data = {input_name: x.asnumpy()}
res = vm.run(**ff_data)
print("vm.run:", res, res.dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment