Skip to content

Instantly share code, notes, and snippets.

@KellenSunderland
Created April 29, 2019 21:16
Show Gist options
  • Save KellenSunderland/686522830475dfc7073b5d7a97e89d24 to your computer and use it in GitHub Desktop.
Save KellenSunderland/686522830475dfc7073b5d7a97e89d24 to your computer and use it in GitHub Desktop.
MXNet Benchmarking Script
import importlib
import time
import timeit
from collections import namedtuple

import mxnet as mx
import numpy as np
def runMx(ctx, mod, data, num_batches, runType):
    """Time ``num_batches`` forward passes of ``mod`` on one fixed input batch.

    Args:
        ctx: MXNet context the input is copied to (``main`` passes
            ``mx.gpu(device_id)``).
        mod: A bound ``mx.mod.Module`` whose parameters are already set.
        data: Array-like input batch accepted by ``mx.nd.array``.
        num_batches: Number of forward passes to run and time.
        runType: Label printed before the run (e.g. ``'dry run'``).

    Returns:
        ``(t, out)``: ``t`` is the summed forward time in seconds (only the
        forward call plus waiting for its outputs is timed), and ``out`` is
        the last output array of the final pass, or ``None`` if
        ``num_batches`` is 0.
    """
    print('%s MXNet' % (runType))
    Batch = namedtuple('Batch', ['data'])
    # The input is identical for every pass, so copy it to the device once,
    # and block until the copy has finished so it cannot leak into the timer.
    data_mx = mx.nd.array(data, ctx)
    data_mx.wait_to_read()
    batch = Batch([data_mx])
    t = 0.0
    out = None  # defined even when num_batches == 0
    for _ in range(num_batches):
        tic = timeit.default_timer()
        mod.forward(batch)
        outputs = mod.get_outputs()
        # MXNet schedules work asynchronously; wait on every output so the
        # measured interval covers the computation, not just its dispatch.
        for out in outputs:
            out.wait_to_read()
        t += timeit.default_timer() - tic
    return t, out
def main():
    """Benchmark inference of the ``symbol_fcnxs`` FCN-8s symbol on one GPU.

    Builds an all-zero input batch, binds the symbol for inference on
    ``mx.gpu(device_id)`` with placeholder parameter arrays, performs an
    untimed "dry run" followed by a timed run, and prints the input/output
    shapes plus per-frame and total forward time.

    Raises:
        ValueError: if ``num_images_run`` is smaller than ``batch_size``
            (no whole batch would be timed).
    """
    print('# MxNet: %s %s' % (mx.__file__, mx.__version__))
    C = 3
    Y = 320
    X = 240
    batch_size = 16
    symbolName = 'symbol_fcnxs'
    device_id = 0
    num_images_dryrun = 128
    num_images_run = 1024
    # Only whole batches are run; leftover images are dropped.
    num_batches_dryrun = num_images_dryrun // batch_size
    num_batches_run = num_images_run // batch_size
    if num_batches_run == 0:
        raise ValueError('num_images_run (%d) must be >= batch_size (%d)'
                         % (num_images_run, batch_size))

    # Dummy all-zero batch laid out as (batch, C, Y, X).
    data_shapes = (batch_size, C, Y, X)
    data = np.zeros(data_shapes, dtype=np.uint8)

    # bind GPU
    ctx = mx.gpu(device_id)

    # load symbol
    print('# load %s' % (symbolName))
    net = importlib.import_module(symbolName)
    sym = net.get_fcn8s_symbol()

    # Placeholder weights: bind the bare symbol once only to obtain correctly
    # shaped parameter arrays. The keyword is ``grad_req`` -- simple_bind
    # ignores unknown non-tuple kwargs such as ``grad=``, which would silently
    # leave the default 'write' and allocate unused gradient buffers.
    shape_exec = sym.simple_bind(ctx, data=data_shapes, grad_req='null')
    arg_params = shape_exec.arg_dict
    aux_params = shape_exec.aux_dict

    # Bind the inference module (Module.bind returns None, so nothing to keep).
    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=[])
    mod.bind(for_training=False, inputs_need_grad=False,
             data_shapes=[('data', data_shapes)])
    # allow_missing=True: do not raise if a module parameter is absent from
    # the placeholder dicts.
    mod.set_params(arg_params, aux_params, allow_missing=True)

    # Untimed dry run (its timing is discarded), then the measured run.
    runMx(ctx, mod, data, num_batches_dryrun, 'dry run')
    t, out = runMx(ctx, mod, data, num_batches_run, 'run')

    # Single-argument print() calls behave the same on Python 2 and 3.
    print("# inputs: ({} {} {} {})".format(
        len(data), data[0].shape[0], data[0].shape[1], data[0].shape[2]))
    print("# outputs:  %s" % (out.shape if out is not None else None,))
    num_frames = batch_size * num_batches_run
    print("# forward time: %1.2f ms/frame, %1.2f ms total; batch_size: %d; symbol: %s"
          % (t * 1000 / num_frames, t * 1000, batch_size, symbolName))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment