# benchmarkHybridBlock
#
# GitHub gist by @ThomasDelteil
# (ThomasDelteil/b9ee7171c18e2db3ba47cb4c0439f116), last active June 16, 2018.
#
# Times three epochs of MNIST fine-tuning of a Gluon ResNet-50 v2 across the
# combinations of the HybridBlock.hybridize() options static_alloc and
# static_shape.
import multiprocessing, time
import mxnet as mx
from mxnet import nd, gluon, autograd
# Target device: GPU 0. Every training batch is moved here, and the new
# output layer is initialised here.
ctx = mx.gpu(0)
# Data: per-sample preprocessing applied to each MNIST image.
def transform(x):
    """Turn one (H, W, C) image into a 3-channel (C*3, H, W) float32 array.

    The channel axis is moved to the front, then the channel(s) are
    repeated three times along that axis. For a single-channel MNIST
    image (presumably (28, 28, 1)), the result is (3, 28, 28), so it
    fits a network that expects 3-channel input.
    """
    chw = x.transpose((2, 0, 1))
    tripled = chw.repeat(repeats=3, axis=0)
    return tripled.astype('float32')
batch_size = 128
mnist_train = gluon.data.vision.datasets.MNIST(train=True).transform_first(transform)
train_data = gluon.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    # static_shape=True optimises for invariant input shapes. A trailing
    # partial batch (dataset size not a multiple of batch_size) would change
    # the input shape once per epoch and skew the timings, so drop it.
    last_batch='discard',
    # Use all but two CPU cores for loading; never request a negative count.
    num_workers=max(0, multiprocessing.cpu_count() - 2),
)
# Loss: softmax cross-entropy against integer class-index labels.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)
# Network factory.
def get_net(static_alloc=False, static_shape=True):
    """Build a hybridized, pretrained ResNet-50 v2 with a fresh 10-way head.

    The model-zoo weights are loaded onto the module-level ``ctx``. The
    ``output`` layer is replaced by a newly initialised ``Dense(10)`` on the
    same device.

    Args:
        static_alloc: forwarded to ``HybridBlock.hybridize``.
        static_shape: forwarded to ``HybridBlock.hybridize``. MXNet documents
            this option as requiring ``static_alloc=True``.

    Returns:
        The hybridized network.
    """
    # Load the backbone on ``ctx`` (not a hard-coded device), so the backbone
    # and the new head below always live on the same device.
    net = gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=ctx)
    with net.name_scope():
        # Swap the pretrained classifier for a 10-output layer.
        net.output = gluon.nn.Dense(10)
    # Only the new head needs initialising; the backbone is pretrained.
    net.output.initialize(ctx=ctx)
    net.hybridize(static_alloc=static_alloc, static_shape=static_shape)
    return net
# Training loops: time 3 epochs of training for each hybridize() configuration.

# Per-sample input shape as produced by ``transform``. The warm-up batch must
# use the same shape as the real batches.
sample_shape = tuple(mnist_train[0][0].shape)

net = None
trainer = None
for static_alloc in [True, False]:
    for static_shape in [True, False]:
        if static_shape and not static_alloc:
            # The hybridize() docs state that static_shape must be combined
            # with static_alloc=True, so this pairing is not a meaningful
            # configuration to benchmark.
            print("static_alloc:{}, static_shape:{}, skipped "
                  "(static_shape requires static_alloc) \n".format(
                      static_alloc, static_shape))
            continue

        # Release the previous model *and* its trainer (which references the
        # old parameters) before building the next one, so the two models are
        # never alive at the same time.
        net = trainer = None
        net = get_net(static_alloc, static_shape)
        trainer = gluon.Trainer(net.collect_params(),
                                'sgd', {'learning_rate': 0.1})

        # Warm-up: one forward/backward pass under autograd.record() with the
        # real batch shape, so the hybridized training graph is built before
        # the clock starts. No trainer.step() is called, so no optimizer
        # update is applied.
        warmup_data = nd.ones((batch_size,) + sample_shape, ctx=ctx)
        warmup_label = nd.zeros((batch_size,), ctx=ctx)
        with autograd.record():
            warmup_loss = softmax_cross_entropy(net(warmup_data), warmup_label)
        warmup_loss.backward()
        # MXNet executes asynchronously. Block until the warm-up has finished
        # so it is not counted in the timed section.
        nd.waitall()

        tick = time.time()
        for epoch in range(3):
            loss_acc = 0
            for data, label in train_data:
                data = data.as_in_context(ctx)
                label = label.as_in_context(ctx)
                with autograd.record():
                    output = net(data)
                    loss = softmax_cross_entropy(output, label)
                loss.backward()
                loss_acc += loss.mean().asscalar()  # blocking
                trainer.step(data.shape[0])
            print("Epoch [{}]: loss {:.4f}".format(epoch, loss_acc/len(train_data)))
        # trainer.step() is asynchronous too. Wait for the last update to
        # finish before reading the clock.
        nd.waitall()
        print("static_alloc:{}, static_shape:{}, time:{:.4f} \n".format(
            static_alloc, static_shape, time.time()-tick
        ))
# End of gist.