Last active
January 7, 2021 19:57
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import tvm
import tvm.relay.testing.tf as tf_testing
from tvm import relay
from tvm.relay.frontend.tensorflow_parser import TFParser
from tvm.relay.op.contrib import tensorrt
# Usage (reproduces the NMS crash this script is named after; needs a CUDA GPU
# and a TVM build with TensorRT support):
#   wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz
#   tar xvf ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz
#   python3 repro_nms_crash.py
def get_husky_image(size=(512, 512)):
    """Download the sample "huskies" photo and return it as a batched image array.

    The image is fetched through TVM's ``download_testdata``, forced to
    3-channel RGB (so it always matches a ``(..., 3)`` model input), resized,
    and given a leading batch axis.

    Args:
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to ``(512, 512)``, the size this script always used.

    Returns:
        numpy.ndarray of shape ``(1, height, width, 3)`` holding 8-bit RGB data.
    """
    from tvm.contrib.download import download_testdata
    from PIL import Image

    img_path = download_testdata(
        "https://raw.githubusercontent.com/NVIDIA-AI-IOT/tf_trt_models/master/examples/detection/data/huskies.jpg",
        "huskies.jpg",
        module="data",
    )
    # Use a context manager so the underlying file handle is closed: Image.open
    # is lazy and otherwise keeps the file open until garbage collection.
    with Image.open(img_path) as image:
        # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was deprecated in
        # Pillow 9.1 and removed in Pillow 10, so the old spelling crashes there.
        resized = image.convert("RGB").resize(size, Image.LANCZOS)
    # Add the batch dimension expected by the model's input tensor.
    np_img = np.expand_dims(np.array(resized), 0)
    return np_img
# Output nodes to keep from the TF graph; the same list is handed to the Relay
# frontend so the compiled module returns exactly these tensors.
output_tensor_names = ['detection_classes', 'num_detections', 'detection_boxes', 'detection_scores']

# Directory created by the ``tar xvf`` step in the usage comment above.
model_dir = "ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03"
parser = TFParser(model_dir + "/saved_model", outputs=output_tensor_names)
graph_def = parser.parse()
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

# Load the input before importing the graph so the declared input shape comes
# from the actual array, rather than a second hard-coded copy of the resize
# dimensions that could silently drift out of sync with get_husky_image().
i_data = get_husky_image()
mod, params = relay.frontend.from_tensorflow(
    graph_def, shape={"image_tensor": i_data.shape}, outputs=output_tensor_names
)

# Partition the module for the TensorRT BYOC backend; the returned ``config``
# must reach the build through the "relay.ext.tensorrt.options" key below.
mod, config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
# NOTE(review): FoldScaleAxis is disabled on purpose; the reason is not
# recorded in this script — confirm before re-enabling.
with tvm.transform.PassContext(
    opt_level=3,
    config={"relay.ext.tensorrt.options": config},
    disabled_pass=["FoldScaleAxis"],
):
    vm_exec = relay.vm.compile(mod, "cuda", params=params)

# NOTE(review): later TVM releases rename ``tvm.gpu`` to ``tvm.cuda``; the
# original spelling is kept for the TVM version this repro targets.
vm = tvm.runtime.vm.VirtualMachine(vm_exec, tvm.gpu(0))
res = vm.invoke("main", i_data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment