Created May 22, 2018 00:37
Save masahi/a386c2ce5b5f8c2d9f7af5e09a8d880b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark a pretrained MXNet Gluon model-zoo network compiled with NNVM/TVM.
#
# Steps: load the pretrained Gluon block, convert it to an NNVM graph, compile
# it for `target`, run it once on a constant input, then print the mean
# latency over `num_iter` timed runs.
import mxnet as mx  # NOTE(review): not referenced below; kept from the original gist
import nnvm
import nnvm.compiler
import tvm
import numpy as np
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download  # NOTE(review): not referenced below
from PIL import Image  # NOTE(review): not referenced below
from matplotlib import pyplot as plt  # NOTE(review): not referenced below
from tvm.contrib import graph_runtime

# --- Configuration -----------------------------------------------------------
# Gluon model-zoo name. The gist author also tried "resnet50_v1".
model = "mobilenet1.0"
# TVM target string; the same string selects the runtime device context below.
target = 'opencl'
dtype = 'float32'
# Number of runs averaged by time_evaluator.
num_iter = 100

# --- Load and compile --------------------------------------------------------
block = get_model(model, pretrained=True)
# Constant dummy input. Only its shape and dtype matter for timing.
x = np.ones((1, 3, 224, 224), dtype=dtype)

# Convert the Gluon block into an NNVM symbol plus its parameter dict.
sym, params = nnvm.frontend.from_mxnet(block)
shape_dict = {'data': x.shape}
# build() returns the graph, the compiled module and a params dict. The
# returned params are what gets fed to the runtime module below.
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)

# --- Run and time ------------------------------------------------------------
ctx = tvm.context(target, 0)  # device 0 of the target's device type
m = graph_runtime.create(graph, lib, ctx)
m.set_input('data', tvm.nd.array(x))
m.set_input(**params)

# One untimed run before timing.
m.run()
# NOTE(review): the (1000,) buffer assumes a 1000-class classifier head;
# adjust if `model` has a different output size.
tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))

ftimer = m.module.time_evaluator("run", ctx, num_iter)
elapsed = ftimer().mean
print("Average of %d runs: %f sec, %d inference/sec" % (num_iter, elapsed, 1.0 / elapsed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.