-
-
Save Wheest/9ad2d6a47bbd2cfaa4be530c68ba2f6c to your computer and use it in GitHub Desktop.
TVM Debugger Simple Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import argparse
import ast
import os
import sys

from PIL import Image
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.debugger import debug_executor
from tvm.contrib.download import download_testdata
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
def transform_image(image, mean=(123.0, 117.0, 104.0), std=(58.395, 57.12, 57.375)):
    """Normalize an HWC image into a ``(1, C, H, W)`` float batch.

    Subtracts the per-channel ``mean``, divides by the per-channel ``std``,
    moves the channel axis to the front and prepends a batch axis.

    Args:
        image: Anything ``numpy.asarray`` accepts whose last axis is the
            channel axis with ``len(mean)`` entries (e.g. an RGB ``PIL.Image``).
        mean: Per-channel values subtracted from the pixels. The defaults are
            the constants this function has always used.
        std: Per-channel divisors applied after the mean subtraction.

    Returns:
        A new float64 ``numpy.ndarray`` of shape ``(1, C, H, W)``; the input
        is never modified.
    """
    # Work in float64 explicitly so integer (e.g. uint8) pixels are normalized
    # without truncation; the subtraction below allocates a fresh array.
    pixels = np.asarray(image, dtype=np.float64)
    pixels = (pixels - np.asarray(mean, dtype=np.float64)) / np.asarray(std, dtype=np.float64)
    # HWC -> CHW, then add the leading batch dimension.
    return pixels.transpose((2, 0, 1))[np.newaxis, :]
def _load_example_inputs():
    """Download the sample cat image and ImageNet label file.

    Returns:
        ``(data, synset)``: ``data`` is ``transform_image`` applied to the
        image resized to 224x224; ``synset`` is the parsed label file,
        presumably a dict mapping class id to a human-readable name (as its
        file name suggests).
    """
    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_name = "cat.png"
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    img_path = download_testdata(img_url, img_name, module="data")
    synset_path = download_testdata(synset_url, synset_name, module="data")
    with open(synset_path, encoding="utf-8") as f:
        # The label file is a Python literal fetched from the network;
        # literal_eval parses it without executing arbitrary code (unlike eval).
        synset = ast.literal_eval(f.read())
    # Close the image file deterministically; convert("RGB") guards against
    # palette/RGBA PNGs, since transform_image normalizes exactly 3 channels.
    with Image.open(img_path) as img:
        image = img.convert("RGB").resize((224, 224))
    return transform_image(image), synset


def _build_resnet18_with_softmax(input_shape, target):
    """Import pretrained MXNet ResNet-18 v1, append a softmax, and compile it.

    Args:
        input_shape: Shape of the network input named ``"data"``.
        target: TVM target to compile for.

    Returns:
        The factory module returned by ``relay.build``.
    """
    block = get_model("resnet18_v1", pretrained=True)
    mod, params = relay.frontend.from_mxnet(block, {"data": input_shape})
    # We want probabilities, so wrap the network body in a softmax. The new
    # function must be put back into a module: building the old ``mod`` would
    # silently drop the softmax.
    func = mod["main"]
    func = relay.Function(
        func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs
    )
    mod = tvm.IRModule.from_expr(func)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    return lib


def _create_executor(mode, lib, dev):
    """Instantiate the executor flavour named by ``mode`` for the built ``lib``.

    Raises:
        ValueError: if ``mode`` is not ``"tutorial"``, ``"alt"`` or ``"normal"``.
    """
    if mode == "tutorial":
        # Try and run the profiler according to the TVM docs at
        # https://tvm.apache.org/docs/dev/debugger.html
        m = debug_executor.create(lib.graph_json, lib, dev, dump_root="/tmp/tvmdbg")
        print("Compiled module")
        return m
    if mode == "alt":
        return debug_executor.GraphModuleDebug(
            lib["debug_create"]("default", dev),
            [dev],
            lib.graph_json,
            dump_root="/tmp/tvmdbg",
        )
    if mode == "normal":
        return graph_executor.GraphModule(lib["default"](dev))
    raise ValueError(f"unknown mode {mode!r}; expected 'tutorial', 'alt' or 'normal'")


def main(args):
    """Compile ResNet-18 (+ softmax) with TVM and run one inference on a cat image.

    Args:
        args: Parsed CLI namespace. ``args.mode`` selects the executor:
            ``"tutorial"`` uses ``debug_executor.create`` as in the TVM
            debugger docs, ``"alt"`` wraps the factory's ``debug_create``
            function in ``GraphModuleDebug``, and ``"normal"`` uses the plain
            graph executor. Both debug modes dump to ``/tmp/tvmdbg``.

    Raises:
        ValueError: if ``args.mode`` is not one of the modes above.
    """
    dtype = "float32"
    out_shape = (1, 1000)
    dev = tvm.cpu(0)
    target = tvm.target.Target("llvm", host="llvm")

    data, synset = _load_example_inputs()
    lib = _build_resnet18_with_softmax(data.shape, target)
    m = _create_executor(args.mode, lib, dev)
    print(f"Created {args.mode} module without dying")

    # set inputs
    m.set_input("data", tvm.nd.array(data.astype(dtype)))
    # relay.build binds the frontend params as constants and re-emits them
    # under new names (p0, p1, ...), so the executor needs the *built* params;
    # the frontend names would be silently ignored by set_input.
    m.set_input(**lib.get_params())
    # execute
    m.run()
    tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).numpy()
    print(f"Probably successfully ran {args.mode} inference")
    top1 = int(np.argmax(tvm_out[0]))
    print(f"Top-1 prediction: id {top1}, {synset[top1]}")
if __name__ == "__main__":
    # Script entry point: pick which executor flavour to exercise, then run.
    cli = argparse.ArgumentParser(description="Run a model using the TVM debugger")
    cli.add_argument(
        "--mode",
        help="What version to try running",
        choices=["tutorial", "alt", "normal"],
        default="tutorial",
    )
    main(cli.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment