Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Last active October 5, 2017 18:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/342131885b19a8602006e2e1477e8af7 to your computer and use it in GitHub Desktop.
Save zhreshold/342131885b19a8602006e2e1477e8af7 to your computer and use it in GitHub Desktop.
FullyConnected (FC) layer performance benchmark (MXNet)
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
def get_bench_net(num_hidden=10000):
    """Return the symbol benchmarked by this script.

    The graph is a single ``FullyConnected`` layer applied to a free input
    variable named ``'data'``.

    Args:
        num_hidden: Number of output units of the FullyConnected layer.

    Returns:
        The ``mx.sym.FullyConnected`` output symbol.
    """
    inputs = mx.sym.var('data')
    return mx.sym.FullyConnected(inputs, num_hidden=num_hidden)
# Benchmark configuration.
num_in = 10000      # input feature dimension of the synthetic data
num_out = 10000     # number of hidden units in the FullyConnected layer
num_example = 1024  # rows of synthetic input data
batch_size = 16
num_iter = 100      # number of timed forward+backward passes

net = get_bench_net(num_out)
# The network has no label input; label_names=None stops Module from
# expecting its default 'softmax_label' argument.
mod = mx.mod.Module(net, label_names=None)
data = np.random.rand(num_example, num_in)
nd_iter = mx.io.NDArrayIter(data, batch_size=batch_size)
mod.bind(data_shapes=nd_iter.provide_data, label_shapes=None)
mod.init_params()

# forward/backward need a DataBatch: take one batch from the iterator and
# reuse it so every timed iteration does identical work.
batch = nd_iter.next()
# FullyConnected is not a loss layer, so backward() needs an explicit head
# gradient for each output (forward_backward() cannot supply one).
# Ones of the output shape are used as the head gradient.
out_grads = [mx.nd.ones(shape) for _, shape in mod.output_shapes]

# Warm-up pass, excluded from the measurement, so one-time first-call costs
# do not skew the timing.
mod.forward(batch, is_train=True)
mod.backward(out_grads=out_grads)
mx.nd.waitall()

start = timer()
for _ in range(num_iter):
    mod.forward(batch, is_train=True)
    mod.backward(out_grads=out_grads)
# MXNet executes operations asynchronously; block until all queued work
# has finished before stopping the clock.
mx.nd.waitall()
elapsed = timer() - start
print('elapsed time:', elapsed)
print('time per iteration:', elapsed / num_iter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment