#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Example that trains a small convolutional model on the MNIST dataset
This example has some command line arguments that enable different neon features.
Examples:
python mnist_conv.py -b cpu -e 10
Run the example for 10 epochs of mnist data using the CPU backend
python mnist_conv.py -b cpu --eval_freq 1
After each training epoch the validation/test data set will be processed through the model
and the cost will be displayed.
"""
import logging
from neon.data import ArrayIterator, load_mnist
from neon.initializers import Constant, Kaiming
from neon.layers import GeneralizedCost, Affine, Conv, Pooling
from neon.models import Model
from neon.optimizers import GradientDescentMomentum
from neon.transforms import Rectlin, CrossEntropyMulti, Misclassification, Softmax
from neon.callbacks.callbacks import Callbacks
from neon.util.argparser import NeonArgparser
# parse the command line arguments
parser = NeonArgparser(__doc__)
args = parser.parse_args()
logger = logging.getLogger()
# load up the mnist data set
# split into train and test sets
(X_train, y_train), (X_test, y_test), nclass = load_mnist(path=args.data_dir)
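# load_mnist returns the images as flattened rows (one 784-element vector per
# 28x28 image) and nclass as the number of label classes (10 digits)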
# setup a training set iterator
train_set = ArrayIterator(X_train, y_train, nclass=nclass, lshape=(1, 28, 28))
# setup a validation data set iterator
valid_set = ArrayIterator(X_test, y_test, nclass=nclass, lshape=(1, 28, 28))
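# lshape=(1, 28, 28) tells the iterator to present each flat row as a
# (channels, height, width) tensor so the Conv layers see 2D images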
# setup weight initialization functions
init_norm = Kaiming()
init_const = Constant(0.0)
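# Kaiming (He) initialization scales weights for ReLU units;
# Constant(0.0) zeros the Affine biases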
# setup model layers
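# LeNet-style stack; feature map shapes per stage (assuming 1x28x28 inputs):
#   Conv 5x5, 8 maps, padding 2  ->  8 x 28 x 28
#   Pool 2x2, stride 2           ->  8 x 14 x 14
#   Conv 5x5, 16 maps, no pad    -> 16 x 10 x 10
#   Pool 2x2, stride 2           -> 16 x  5 x  5
#   Affine 120 -> 84 -> 10 (Softmax over the 10 digit classes)
# note: in neon, Conv((5, 5, 8)) means a 5x5 filter with 8 output feature maps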
layers = [Conv((5, 5, 8), padding=2, init=init_norm),
          Pooling((2, 2), strides=2),
          Conv((5, 5, 16), init=init_norm),
          Pooling((2, 2), strides=2),
          Affine(nout=120, init=init_norm, activation=Rectlin(), bias=init_const),
          Affine(nout=84, init=init_norm, activation=Rectlin(), bias=init_const),
          Affine(nout=10, init=init_norm, activation=Softmax())]
# setup cost function as CrossEntropy
cost = GeneralizedCost(costfunc=CrossEntropyMulti())
# setup optimizer
optimizer = GradientDescentMomentum(0.01, momentum_coef=0.9)
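# plain SGD with learning rate 0.01 and momentum 0.9 (no schedule or decay)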
# initialize model object
model = Model(layers=layers)
# configure callbacks
callbacks = Callbacks(model, eval_set=valid_set, metric=Misclassification(), **args.callback_args)
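# with --eval_freq set, the callbacks run eval_set through the model on that
# schedule and report the Misclassification metric alongside training progress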
# run fit
model.fit(train_set, optimizer=optimizer, num_epochs=args.epochs, cost=cost, callbacks=callbacks)
print('Misclassification error = %.1f%%' % (model.eval(valid_set, metric=Misclassification())*100))