Skip to content

Instantly share code, notes, and snippets.

@Wheest
Last active May 6, 2022 19:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Wheest/9ad2d6a47bbd2cfaa4be530c68ba2f6c to your computer and use it in GitHub Desktop.
Save Wheest/9ad2d6a47bbd2cfaa4be530c68ba2f6c to your computer and use it in GitHub Desktop.
TVM Debugger Simple Example
#!/usr/bin/env python
import argparse
import ast
import os
import sys

from PIL import Image
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.debugger import debug_executor
from tvm.contrib.download import download_testdata
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
def transform_image(image, mean=(123.0, 117.0, 104.0), std=(58.395, 57.12, 57.375)):
    """Normalise an HWC image into a batched NCHW float64 array.

    Each channel has ``mean`` subtracted and is divided by ``std``; the result
    is transposed from (H, W, C) to (C, H, W) and given a leading batch axis.
    The defaults appear to be per-channel ImageNet statistics on a 0-255
    scale (presumably matching the Gluon model-zoo preprocessing).

    Args:
        image: a PIL image or array-like of shape (H, W, C).
        mean: per-channel values subtracted from the pixels.
        std: per-channel values the centred pixels are divided by.

    Returns:
        A new float64 ``numpy.ndarray`` of shape (1, C, H, W). The input is
        never modified.

    Raises:
        ValueError: if ``image`` is not 3-D (e.g. grayscale or palette data).
    """
    # Cast before subtracting so uint8 pixels cannot wrap around.
    pixels = np.asarray(image, dtype=np.float64)
    if pixels.ndim != 3:
        raise ValueError(
            f"expected an (H, W, C) image, got array of shape {pixels.shape}"
        )
    # The subtraction allocates a new array, so the caller's data is untouched.
    pixels = (pixels - np.asarray(mean, dtype=np.float64)) / np.asarray(
        std, dtype=np.float64
    )
    return pixels.transpose((2, 0, 1))[np.newaxis, :]
def _download_inputs():
    """Download the sample cat image and the ImageNet label file.

    Returns:
        (img_path, synset_path): local paths returned by ``download_testdata``.
    """
    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_name = "cat.png"
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    img_path = download_testdata(img_url, img_name, module="data")
    synset_path = download_testdata(synset_url, synset_name, module="data")
    return img_path, synset_path


def _load_synset(synset_path):
    """Parse the class-id -> label file, which holds a Python literal.

    The file is downloaded from the network, so it is parsed with
    ``ast.literal_eval`` (literals only) rather than ``eval`` (arbitrary code).
    Presumably the literal is a dict mapping int class id to label string.
    """
    with open(synset_path, encoding="utf-8") as f:
        return ast.literal_eval(f.read())


def _build_with_softmax(block, data_shape, target, opt_level=3):
    """Import a Gluon block, append a softmax, and compile it with Relay.

    Returns:
        The factory module produced by ``relay.build``.
    """
    mod, params = relay.frontend.from_mxnet(block, {"data": data_shape})
    # We want probabilities rather than logits, so wrap the body in softmax.
    func = mod["main"]
    func = relay.Function(
        func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs
    )
    # The new function must replace the module contents; otherwise relay.build
    # compiles the original logits graph and the softmax is silently dropped.
    mod = tvm.IRModule.from_expr(func)
    with tvm.transform.PassContext(opt_level=opt_level):
        return relay.build(mod, target=target, params=params)


def _create_module(mode, lib, dev):
    """Create the graph executor variant selected by ``mode``.

    Raises:
        ValueError: if ``mode`` is not "tutorial", "alt" or "normal".
    """
    if mode == "tutorial":
        # Try and run the profiler according to the TVM docs at
        # https://tvm.apache.org/docs/dev/debugger.html
        m = debug_executor.create(lib.graph_json, lib, dev, dump_root="/tmp/tvmdbg")
        print("Compiled module")
        return m
    if mode == "alt":
        return debug_executor.GraphModuleDebug(
            lib["debug_create"]("default", dev),
            [dev],
            lib.graph_json,
            dump_root="/tmp/tvmdbg",
        )
    if mode == "normal":
        return graph_executor.GraphModule(lib["default"](dev))
    raise ValueError(
        f"unknown mode {mode!r}; expected 'tutorial', 'alt' or 'normal'"
    )


def main(args):
    """Compile ResNet-18 with TVM and run one inference on a sample image.

    Args:
        args: parsed command-line namespace; only ``args.mode`` is read. It
            selects the executor: "tutorial" (debug executor built from the
            graph JSON), "alt" (debug executor from the factory's
            ``debug_create``) or "normal" (plain graph executor).

    Raises:
        ValueError: if ``args.mode`` is not a supported mode.
    """
    dtype = "float32"
    out_shape = (1, 1000)
    target = tvm.target.Target("llvm", host="llvm")
    dev = tvm.cpu(0)

    block = get_model("resnet18_v1", pretrained=True)
    img_path, synset_path = _download_inputs()
    synset = _load_synset(synset_path)
    # Close the file handle promptly, and force 3-channel RGB so RGBA or
    # palette PNGs still match the per-channel normalisation.
    with Image.open(img_path) as img:
        image = img.convert("RGB").resize((224, 224))
    data = transform_image(image)

    lib = _build_with_softmax(block, data.shape, target)
    m = _create_module(args.mode, lib, dev)
    print(f"Created {args.mode} module without dying")

    # set inputs
    m.set_input("data", tvm.nd.array(data.astype(dtype)))
    # relay.build binds the frontend params as constants and re-exports them
    # under new names; set_input silently skips unknown keys, so the compiled
    # params (not the frontend dict) must be supplied. Executors created from
    # the raw graph JSON ("tutorial") do not load these on their own.
    m.set_input(**lib.get_params())
    # execute
    m.run()
    tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).numpy()
    print(f"Probably successfully ran {args.mode} inference")
    top1 = int(np.argmax(tvm_out[0]))
    print(f"Top-1 prediction: {synset[top1]} (p={tvm_out[0, top1]:.4f})")
if __name__ == "__main__":
    # Command-line entry point: pick which executor variant to exercise.
    cli = argparse.ArgumentParser(
        description="Run a model using the TVM debugger",
    )
    cli.add_argument(
        "--mode",
        choices=["tutorial", "alt", "normal"],
        default="tutorial",
        help="What version to try running",
    )
    main(cli.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment