Skip to content

Instantly share code, notes, and snippets.

@amohant4
Last active July 17, 2019 18:31
Show Gist options
  • Save amohant4/b47a6563637c6a8bbc133b5f5bbce4c9 to your computer and use it in GitHub Desktop.
import mxnet as mx
from mxnet import init
# Create a mxnet symbol for the graph ~~~
def create_net_moduleAPI():
    """Build the LeNet symbolic graph in MXNet.

    Arguments: None

    Returns:
        mx.sym.Symbol: symbol for LeNet, ending in a SoftmaxOutput node
        (MXNet will expect the label under the name 'softmax_label').
    """
    net = mx.sym.Variable('data')
    # Conv block 1: 6 filters of 5x5 -> ReLU -> 2x2 max-pool, stride 2.
    net = mx.sym.Convolution(net, name='conv1', num_filter=6, kernel=(5, 5))
    net = mx.sym.Activation(net, name='conv1_relu', act_type="relu")
    net = mx.sym.Pooling(net, name='maxpool1', pool_type='max', kernel=(2, 2), stride=(2, 2))
    # Conv block 2: 16 filters of 5x5 -> ReLU -> 2x2 max-pool, stride 2.
    net = mx.sym.Convolution(net, name='conv2', num_filter=16, kernel=(5, 5))
    net = mx.sym.Activation(net, name='conv2_relu', act_type="relu")
    net = mx.sym.Pooling(net, name='maxpool2', pool_type='max', kernel=(2, 2), stride=(2, 2))
    # Classifier head: 120 -> 84 -> 10 fully-connected layers (classic LeNet-5 sizes).
    net = mx.sym.FullyConnected(net, name='fc1', num_hidden=120)
    net = mx.sym.Activation(net, name='fc1relu', act_type="relu")
    net = mx.sym.FullyConnected(net, name='fc2', num_hidden=84)
    net = mx.sym.Activation(net, name='fc2relu', act_type="relu")
    net = mx.sym.FullyConnected(net, name='fc3', num_hidden=10)
    # Softmax + cross-entropy loss output; 10 classes for MNIST digits.
    net = mx.sym.SoftmaxOutput(net, name='softmax')
    return net
# Use the symbol to create a Module object. ~~~
mod = mx.mod.Module(symbol=create_net_moduleAPI(),   # Symbol of the graph
                    context=mx.cpu(),                # mx.gpu() if you got GPUs
                    data_names=['data'],             # name of the input symbol
                    label_names=['softmax_label'])   # '_label' is appended by mxnet

# Get the dataset ~~~~
# Using MXNet's predefined utilities to make life easier
mnist = mx.test_utils.get_mnist()
batch_size = 256
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

# Bind the module to data (infers shapes of all nodes in the graph).
# Take the shapes from the iterator instead of hard-coding (256,1,28,28):
# this stays correct if batch_size changes above.
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=init.Xavier())  # Initialize parameters (Xavier/Glorot)
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1),))  # initialize optimizer
metric = mx.metric.create('acc')  # Track accuracy to see how well we are doing.
# Training loop ~~~~
for epoch in range(5):
    train_iter.reset()   # Rewind the training iterator for this epoch
    metric.reset()       # Reset the metric so accuracy accumulates per-epoch
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # Forward pass on the batch
        mod.update_metric(metric, batch.label)  # Accumulate accuracy on the batch
        mod.backward()                          # Backprop: compute gradients
        mod.update()                            # Apply the SGD parameter update
    # Report once per epoch; metric.get() returns a (name, value) tuple.
    print('Epoch %d, Training %s' % (epoch, metric.get()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment