Created
May 26, 2016 22:25
-
-
Save nervetumer/976ea1f6e6fa1b96db2b7b9bdd26cd6b to your computer and use it in GitHub Desktop.
conv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Example that trains a small convolutional model on the MNIST dataset

This example has some command line arguments that enable different neon features.

Examples:

    python mnist_conv.py -b cpu -e 10

        Run the example for 10 epochs of mnist data using the CPU backend

    python mnist_conv.py -b cpu --eval_freq 1

        After each training epoch the validation/test data set will be processed through the model
        and the cost will be displayed.
"""
import logging

from neon.data import ArrayIterator, load_mnist
from neon.initializers import Constant, Kaiming
from neon.layers import GeneralizedCost, Affine, Conv, Pooling
from neon.models import Model
from neon.optimizers import GradientDescentMomentum
from neon.transforms import Rectlin, CrossEntropyMulti, Misclassification, Softmax
from neon.callbacks.callbacks import Callbacks
from neon.util.argparser import NeonArgparser

# Parse the command line arguments; the module docstring is handed to the
# parser so it can be shown as the help description.
parser = NeonArgparser(__doc__)
args = parser.parse_args()

# NOTE(review): this root logger is not referenced again in this script.
logger = logging.getLogger()

# Load the MNIST data set, already split into train and test sets.
(X_train, y_train), (X_test, y_test), nclass = load_mnist(path=args.data_dir)

# Both iterators use the same per-sample layout: lshape=(1, 28, 28),
# presumably (channels, height, width) -- see neon's ArrayIterator docs.
# Training set iterator.
train_set = ArrayIterator(X_train, y_train, nclass=nclass, lshape=(1, 28, 28))
# Validation set iterator (built from the test split).
valid_set = ArrayIterator(X_test, y_test, nclass=nclass, lshape=(1, 28, 28))

# Weight initializers: Kaiming for weights, constant zero for biases.
init_norm = Kaiming()
init_const = Constant(0.0)

# Model layers: two conv/pool stages followed by three fully connected layers.
# NOTE(review): no activation is passed to the Conv layers and no bias to the
# final Affine layer, unlike the two hidden Affine layers -- confirm intended.
layers = [
    Conv((5, 5, 8), padding=2, init=init_norm),
    Pooling((2, 2), strides=2),
    Conv((5, 5, 16), init=init_norm),
    Pooling((2, 2), strides=2),
    Affine(nout=120, init=init_norm, activation=Rectlin(), bias=init_const),
    Affine(nout=84, init=init_norm, activation=Rectlin(), bias=init_const),
    Affine(nout=10, init=init_norm, activation=Softmax()),
]

# Cost function: multi-class cross entropy.
cost = GeneralizedCost(costfunc=CrossEntropyMulti())

# Optimizer: gradient descent with momentum (0.01 is the first positional
# argument, presumably the learning rate -- see neon's optimizer docs).
optimizer = GradientDescentMomentum(0.01, momentum_coef=0.9)

# Model object wrapping the layer list.
model = Model(layers=layers)

# Callbacks evaluate on the validation set with the misclassification metric;
# any extra callback options come from the parsed command line arguments.
callbacks = Callbacks(model, eval_set=valid_set, metric=Misclassification(),
                      **args.callback_args)

# Train the model.
model.fit(train_set, optimizer=optimizer, num_epochs=args.epochs, cost=cost,
          callbacks=callbacks)

# Report the final misclassification rate on the validation set as a percentage.
print('Misclassification error = %.1f%%' %
      (model.eval(valid_set, metric=Misclassification())*100))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment