Skip to content

Instantly share code, notes, and snippets.

@masahi
Created June 17, 2021 21:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masahi/9348db919edb105912b94b84792dd7d3 to your computer and use it in GitHub Desktop.
Save masahi/9348db919edb105912b94b84792dd7d3 to your computer and use it in GitHub Desktop.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# See:
# - https://tvm.apache.org/docs/tutorials/frontend/from_onnx.html
# - https://github.com/apache/tvm/blob/main/tutorials/frontend/from_onnx.py
# - https://github.com/onnx/models
import subprocess
import os
import sys
import posixpath
from six.moves.urllib.request import urlretrieve
import glob
import onnx
from onnx import numpy_helper
import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor
from tvm.runtime.vm import VirtualMachine
def get_value_info_shape(value_info):
    """Return the shape described by an ONNX value-info entry as a tuple of ints.

    Walks ``value_info.type.tensor_type.shape.dim`` and takes each dimension's
    ``dim_value``. Any value below 1 is replaced by 1. For example, 0 is
    presumably what an unset or symbolic dimension reports; confirm against
    the ONNX proto definition.

    Note: not called anywhere in this script.
    """
    dims = value_info.type.tensor_type.shape.dim
    return tuple(max(dim.dim_value, 1) for dim in dims)
# ONNX model-zoo archives (.tar.gz) that the loop below downloads, imports into
# TVM Relay and runs once each.
# NOTE(review): the onnx/models repository may have reorganized since these
# raw/master paths were written -- verify they still resolve.
_MODEL_ZOO_RAW = 'https://github.com/onnx/models/raw/master'
urls = [
    f'{_MODEL_ZOO_RAW}/vision/object_detection_segmentation/yolov4/model/yolov4.tar.gz',
    f'{_MODEL_ZOO_RAW}/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz',
    # f'{_MODEL_ZOO_RAW}/text/machine_comprehension/roberta/model/roberta-base-11.tar.gz',
    # XXX: Often segfaults (presumably the reason the roberta entry above is
    # commented out -- confirm).
    f'{_MODEL_ZOO_RAW}/text/machine_comprehension/gpt-2/model/gpt2-10.tar.gz',
]
# One (result, url) tuple per processed model; printed as the final report.
summary = []

# Models are compiled for the "llvm" target and run on device 0 of it.
target = "llvm"
ctx = tvm.device(target, 0)
# Per-model driver. For each archive in ``urls``, it downloads the archive,
# extracts its ONNX model and first reference input set, then imports,
# compiles and runs the model with TVM. Every model gets an 'ok'/'not ok'
# entry in ``summary``, so one failing model does not abort the remaining ones.
import tarfile


def _fetch_archive(url):
    """Download ``url`` into the working directory unless it is already there.

    The local file name is the last path component of ``url``. An existing
    file with that name is reused without downloading again.

    Returns:
        The local archive file name.

    Raises:
        FileNotFoundError: if the archive is still missing after the download.
    """
    archive = posixpath.basename(url)
    if not os.path.exists(archive):
        print(f'Downloading {url} ...')
        urlretrieve(url, archive)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(archive):
        raise FileNotFoundError(f'Download of {url} did not produce {archive}')
    return archive


def _extract_model(archive):
    """Find the first ``.onnx`` member of a ``.tar.gz`` archive; extract if needed.

    The archive is extracted into the working directory only when that model
    file does not already exist there.

    Returns:
        ``(model_file, model_dir)``: the model path as stored in the archive
        and the directory containing it (``''`` if it is at the archive root).

    Raises:
        ValueError: if the archive contains no ``.onnx`` member.
        FileNotFoundError: if the model file is still missing after extraction.
    """
    # ``with`` guarantees the archive is closed, including on errors.
    with tarfile.open(archive, 'r:gz') as tar:
        # Resolved per archive: a model-less archive fails loudly instead of
        # reusing a name left over from a previous archive.
        model_file = next(
            (member for member in tar.getnames() if member.endswith('.onnx')),
            None,
        )
        if model_file is None:
            raise ValueError(f'No .onnx model found in {archive}')
        if not os.path.exists(model_file):
            print(f'Extracting {archive} ...')
            # The archive comes from the network, so apply the 'data'
            # extraction filter when this Python provides it. It rejects
            # absolute paths, '..' escapes, special files and similar.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(filter='data')
            else:
                tar.extractall()
    if not os.path.exists(model_file):
        raise FileNotFoundError(f'{model_file} missing after extracting {archive}')
    return model_file, os.path.dirname(model_file)


def _load_test_inputs(graph, model_dir, device):
    """Load reference inputs for ``graph`` from the model's first test data set.

    Graph inputs that are also initializers are skipped. The i-th remaining
    input is read from ``<test_data_set>/input_<i>.pb``, where
    ``<test_data_set>`` is the lexicographically first ``test_data_set_*``
    directory under ``model_dir``.

    Returns:
        ``(shape_dict, inputs)``: a map from input name to array shape, and a
        map from input name to ``tvm.nd.array`` on ``device``.

    Raises:
        FileNotFoundError: if there is no test data set or an input file is missing.
    """
    initializers = {initializer.name for initializer in graph.initializer}
    # glob() returns entries in arbitrary order; sort so the choice is stable.
    test_data_sets = sorted(glob.glob(os.path.join(model_dir, 'test_data_set_*')))
    if not test_data_sets:
        raise FileNotFoundError(f'No test_data_set_* directory in {model_dir!r}')
    test_data_set = test_data_sets[0]

    shape_dict = {}
    inputs = {}
    graph_inputs = [gi for gi in graph.input if gi.name not in initializers]
    for i, graph_input in enumerate(graph_inputs):
        tensor = onnx.TensorProto()
        with open(os.path.join(test_data_set, f'input_{i}.pb'), 'rb') as pb_file:
            tensor.ParseFromString(pb_file.read())
        array = numpy_helper.to_array(tensor)
        shape_dict[graph_input.name] = array.shape
        inputs[graph_input.name] = tvm.nd.array(array, device)
    return shape_dict, inputs


def _compile_and_run(onnx_model, shape_dict, inputs, compile_target, device):
    """Import ``onnx_model`` into Relay, compile it with the Relay VM, and run it once.

    The run's outputs are discarded. Success only means that import,
    compilation and execution raised no exception.
    """
    print('Importing graph from ONNX to TVM Relay IR ...')
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
    mod = relay.transform.DynamicToStatic()(mod)
    print(f'Compiling graph from Relay IR to {compile_target} ...')
    with tvm.transform.PassContext(opt_level=3):
        vm_exec = relay.vm.compile(mod, compile_target, params=params)
    vm = VirtualMachine(vm_exec, device)
    vm.set_input("main", **inputs)
    print("Running inference...")
    vm.run()


for url in urls:
    print(f'==> {url} <==')
    try:
        archive = _fetch_archive(url)
        model_file, model_dir = _extract_model(archive)
        print(f'Loading {model_file} ...')
        onnx_model = onnx.load(model_file)
        shape_dict, inputs = _load_test_inputs(onnx_model.graph, model_dir, ctx)
        print(f'Input shapes: {shape_dict}')
        _compile_and_run(onnx_model, shape_dict, inputs, target, ctx)
    # KeyboardInterrupt derives from BaseException, not Exception, so Ctrl-C
    # still stops the whole run without an explicit re-raise clause.
    except Exception as ex:
        print(f'Caught an exception {ex}')
        result = 'not ok'
    else:
        print('Succeeded!')
        result = 'ok'
    summary.append((result, url))
    print()
# Final report: one "<result>\t- <url>" line per model, in processing order.
print('Summary:')
for outcome, model_url in summary:
    print(f'{outcome}\t- {model_url}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment