Skip to content

Instantly share code, notes, and snippets.

@eqy
Created November 27, 2018 19:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save eqy/2a07911ea9ceb1276c9379a40af15d7c to your computer and use it in GitHub Desktop.
Save eqy/2a07911ea9ceb1276c9379a40af15d7c to your computer and use it in GitHub Desktop.
yolov3.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Wrapper around the DarkNet model conversion by Siju Samuel for convenience.
'''
import sys
import nnvm
import numpy as np
from nnvm import to_relay
def get_workload(dtype='float32', data_dir='.'):
    """Download the YOLOv3 DarkNet model and convert it to NNVM.

    Fetches ``yolov3.cfg`` and a prebuilt DarkNet shared library from the
    siju-samuel/darknet GitHub repository and ``yolov3.weights`` from
    pjreddie.com. It then loads the network through the DarkNet C library and
    converts it with ``nnvm.frontend.darknet.from_darknet``.

    Parameters
    ----------
    dtype : str, optional
        Data type passed through to ``from_darknet``. Defaults to
        ``'float32'``.
    data_dir : str, optional
        Directory that the config, weights and library are downloaded into
        and loaded from. It is created if missing. Defaults to the current
        working directory.

    Returns
    -------
    tuple
        The ``(sym, params)`` pair returned by ``from_darknet``.

    Raises
    ------
    NotImplementedError
        If the host platform is neither Linux nor macOS, for which no
        prebuilt DarkNet library is available.
    """
    import os
    from tvm.contrib.download import download
    from nnvm.testing.darknet import __darknetffi__

    model_name = 'yolov3'
    cfg_name = model_name + '.cfg'
    weights_name = model_name + '.weights'
    repo_url = 'https://github.com/siju-samuel/darknet/blob/master/'
    cfg_url = repo_url + 'cfg/' + cfg_name + '?raw=true'
    weights_url = 'https://pjreddie.com/media/files/' + weights_name

    # Pick the prebuilt DarkNet shared library matching the host OS.
    if sys.platform.startswith('linux'):
        lib_name = 'libdarknet2.0.so'
        lib_url = repo_url + 'lib/' + lib_name + '?raw=true'
    elif sys.platform == 'darwin':
        lib_name = 'libdarknet_mac2.0.so'
        lib_url = repo_url + 'lib_osx/' + lib_name + '?raw=true'
    else:
        err = "Darknet lib is not supported on {} platform".format(sys.platform)
        raise NotImplementedError(err)

    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    cfg_path = os.path.join(data_dir, cfg_name)
    weights_path = os.path.join(data_dir, weights_name)
    lib_path = os.path.join(data_dir, lib_name)

    download(cfg_url, cfg_path)
    download(weights_url, weights_path)
    download(lib_url, lib_path)

    # Use an absolute path. A name without a slash would make the dynamic
    # loader search the system library directories instead of data_dir.
    darknet_lib = __darknetffi__.dlopen(os.path.abspath(lib_path))
    # The C API takes byte strings. The trailing 0 is DarkNet's third
    # load_network argument (presumably its `clear` flag) -- confirm
    # against the DarkNet sources.
    net = darknet_lib.load_network(cfg_path.encode('utf-8'),
                                   weights_path.encode('utf-8'), 0)

    print("Converting darknet to nnvm symbols...")
    sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
    return sym, params
if __name__ == '__main__':
    # Build the NNVM version of YOLOv3, then run it through the NNVM -> Relay
    # converter as a smoke test of the conversion path.
    sym, params = get_workload()
    # Batch of one 3-channel 608x608 image. The size presumably matches the
    # width/height in yolov3.cfg -- confirm. 'data' is assumed to be the input
    # name assigned by the darknet frontend.
    input_shape = (1, 3, 608, 608)
    shapes = {'data': input_shape}
    print("Converting to relay...")
    # Convert the symbol returned by get_workload.
    # NOTE(review): `params` is not handed to the converter; check whether
    # from_nnvm accepts it and whether it should be passed.
    to_relay.from_nnvm(sym, shapes, {})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment