Skip to content

Instantly share code, notes, and snippets.

@wkcn
Created May 13, 2018 08:06
Show Gist options
  • Save wkcn/b17dd1bc01c363fe2f244eaa29ceb94a to your computer and use it in GitHub Desktop.
Save wkcn/b17dd1bc01c363fe2f244eaa29ceb94a to your computer and use it in GitHub Desktop.
mnist_count_time
import mxnet as mx
import count_time
import numpy as np

# Load MNIST (downloaded by mx.test_utils if not already present) as a dict of
# arrays keyed by 'train_data', 'train_label', 'test_data', 'test_label'.
mnist = mx.test_utils.get_mnist()

# Fix the seeds for reproducible runs.
# mx.random.seed only seeds MXNet's own generators (presumably what the
# parameter initializer draws from). NDArrayIter(shuffle=True) permutes the
# samples with numpy's global RNG, so numpy must be seeded as well, and before
# the iterators below are constructed. Otherwise the batch order changes
# from run to run.
mx.random.seed(42)
np.random.seed(42)

# Set the compute context: use a GPU if one is detected, otherwise the CPU.
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()

batch_size = 100
# Training batches are shuffled; validation batches keep their original order.
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
# Build a 3-layer MLP symbol (128 -> 64 -> 10) with CountTimeOP timing probes
# inserted after the input and after each hidden activation.
# NOTE(review): none of the layers is given an explicit name, so MXNet presumably
# auto-generates layer/parameter names from creation order -- keep the statement
# order stable if saved checkpoints must keep loading.

# Symbolic input placeholder for the image batch.
data = mx.sym.var('data')
# Timing probe 't0' (first=True starts the timing chain). 'CountTimeOP' is not
# defined in this file; it is presumably registered as a custom op by
# `import count_time`. The op yields two outputs: the input tensor passed
# through (rebound to `data`), and a timing value `t` that is threaded into the
# next probe.
# NOTE(review): the extra kwargs (first, cname, t) are forwarded to the custom
# op's property class, presumably as strings (e.g. first='True'). Confirm that
# count_time parses them accordingly.
data, t = mx.sym.Custom(op_type = 'CountTimeOP', data = data, first = True, cname = 't0')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)
# The first fully-connected layer and the corresponding activation function
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
# Timing probe 't1' after the first hidden layer; consumes the previous `t`.
act1, t = mx.sym.Custom(op_type = 'CountTimeOP', data = act1, t = t, first = False, cname = 't1')
# The second fully-connected layer and the corresponding activation function
fc2 = mx.sym.FullyConnected(data=act1, num_hidden = 64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# Timing probe 't2' after the second hidden layer.
# NOTE(review): the `t` produced here is never used afterwards, so it is not an
# output of `mlp`. Only the pass-through `act2` feeds the classifier.
act2, t = mx.sym.Custom(op_type = 'CountTimeOP', data = act2, t = t, first = False, cname = 't2')
# MNIST has 10 classes
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
import logging

# Lower the root logger's threshold to DEBUG so progress messages (presumably
# emitted by Module.fit and Speedometer through `logging`) are not filtered out.
# NOTE(review): this only sets the level and attaches no handler. Output goes
# wherever the root logger's handler writes (logging's implicit default is
# stderr), not necessarily stdout.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)

# Wrap the MLP symbol in a trainable module bound to the chosen compute context.
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)

# Progress reporter: logs throughput every 100 batches of `batch_size` samples.
progress = mx.callback.Speedometer(batch_size, 100)

mlp_model.fit(
    train_iter,                                # training batches
    eval_data=val_iter,                        # evaluated after each epoch
    optimizer='sgd',                           # plain stochastic gradient descent
    optimizer_params={'learning_rate':0.1},    # constant learning rate
    eval_metric='acc',                         # report classification accuracy
    batch_end_callback=progress,
    num_epoch=10,                              # number of passes over train_iter
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment