# prepare_model.py
# From GitHub Gist eqy/b71d04b73842ce214819ad4c4930e7a1 by @eqy (created April 23, 2019 00:20).
# Compiles Gluon model-zoo networks with TVM (Relay, or the legacy NNVM path)
# into deployable artifacts for an ARM64 Android (Pixel 2) CPU target.
import logging
import mxnet as mx
import tvm
import nnvm.frontend
import nnvm.compiler
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
from tvm import relay
from tvm.contrib import ndk
import os
# LLVM host target for 64-bit ARM Android; passed as `target_host` to relay.build in main().
target_host = 'llvm -target=arm64-linux-android'
def get_model(model_name, batch_size=1):
    """Load a pretrained Gluon vision model and return an optimized Relay function.

    Args:
        model_name: Gluon model-zoo identifier (e.g. 'resnet18_v1', 'inceptionv3').
        batch_size: Leading dimension of the NCHW "data" input.

    Returns:
        The function produced by ``relay.optimize`` at opt_level 3. The converted
        parameter dict is handed to ``relay.optimize`` and not returned separately
        (presumably it is folded into the function -- verify against the TVM version).
    """
    block = vision.get_model(model_name, pretrained=True)
    # inceptionv3 is fed 299x299 images; every other model here gets 224x224.
    side = 224
    if model_name == 'inceptionv3':
        side = 299
    input_shapes = {"data": (batch_size, 3, side, side)}
    relay_func, relay_params = relay.frontend.from_mxnet(block, input_shapes)
    with relay.build_config(opt_level=3):
        optimized = relay.optimize(relay_func, target=None, params=relay_params)
    return optimized
def get_model_nnvm(model_name, batch_size=1):
    """Load a pretrained Gluon vision model and convert it with the NNVM frontend.

    Args:
        model_name: Gluon model-zoo identifier (e.g. 'resnet18_v1', 'inceptionv3').
        batch_size: Leading dimension of the NCHW "data" input.

    Returns:
        A ``(sym, params)`` pair as produced by ``nnvm.frontend.from_mxnet``.
    """
    block = vision.get_model(model_name, pretrained=True)
    # Square input side: 299 for inceptionv3, 224 for everything else.
    side = {'inceptionv3': 299}.get(model_name, 224)
    shape_dict = {"data": (batch_size, 3, side, side)}
    graph_sym, weights = nnvm.frontend.from_mxnet(block, shape_dict)
    return graph_sym, weights
def main_nnvm(model_str):
    """Compile a Gluon model-zoo network with the legacy NNVM compiler for a Pixel 2 CPU.

    Writes three artifacts into the directory ``model_str`` (created if missing):
    ``deploy_lib_cpu.so`` (linked via ``ndk.create_shared``), ``deploy_graph.json``
    and ``deploy_param.params``.

    Args:
        model_str: Gluon model-zoo identifier; also used as the output directory name.
    """
    print(model_str)
    print("getting model...")
    sym, params = get_model_nnvm(model_str)
    os.makedirs(model_str, exist_ok=True)
    # The build shape must match the shape get_model_nnvm converted the model with
    # (batch 1; 299x299 for inceptionv3, 224x224 otherwise). A hard-coded 224x224
    # here would mismatch inceptionv3.
    img_size = 299 if model_str == 'inceptionv3' else 224
    data_shape = (1, 3, img_size, img_size)
    target = tvm.target.arm_cpu(model='pixel2')
    print("building model...")
    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            sym, target, {"data": data_shape}, params=params, target_host=None)
    print("dumping lib...")
    lib.export_library(os.path.join(model_str, 'deploy_lib_cpu.so'),
                       ndk.create_shared)
    print("dumping graph...")
    with open(os.path.join(model_str, 'deploy_graph.json'), 'w') as f:
        f.write(graph.json())
    print("dumping params...")
    with open(os.path.join(model_str, 'deploy_param.params'), 'wb') as f:
        f.write(nnvm.compiler.save_param_dict(params))
def main(model_str):
    """Compile a Gluon model-zoo network with Relay for a Pixel 2 (ARM64 Android) CPU.

    Writes three artifacts into the directory ``model_str`` (created if missing):
    ``deploy_lib_cpu.so`` (linked via ``ndk.create_shared``), ``deploy_graph.json``
    and ``deploy_param.params``.

    Args:
        model_str: Gluon model-zoo identifier; also used as the output directory name.
    """
    print(model_str)
    print("getting model...")
    func = get_model(model_str)
    os.makedirs(model_str, exist_ok=True)
    print("building...")
    target = tvm.target.arm_cpu(model='pixel2')
    print("(relay)")
    with relay.build_config(opt_level=3):
        # target_host is the module-level ARM64 Android LLVM target string.
        graph, lib, params = relay.build(func, target, target_host=target_host)
    print("dumping lib...")
    lib.export_library(os.path.join(model_str, 'deploy_lib_cpu.so'),
                       ndk.create_shared)
    print("dumping graph...")
    # The Relay graph is written as-is in text mode (the NNVM path calls graph.json()).
    with open(os.path.join(model_str, 'deploy_graph.json'), 'w') as f:
        f.write(graph)
    print("dumping params...")
    with open(os.path.join(model_str, 'deploy_param.params'), 'wb') as f:
        f.write(relay.save_param_dict(params))
if __name__ == '__main__':
    # Gluon model-zoo names to compile; each gets its own output directory.
    model_names = [
        'mobilenet1.0',
        'mobilenetv2_1.0',
        'resnet18_v1',
        'inceptionv3',
        'squeezenet1.1',
    ]
    for name in model_names:
        # Relay pipeline; call main_nnvm(name) instead to use the legacy NNVM path.
        main(name)
# (End of gist; GitHub comment-section footer omitted.)