Last active
September 12, 2016 21:18
-
-
Save glamp/0c721b0a0225ee19a0bee8003b5ac564 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging

import mxnet as mx
import numpy as np
import cv2
# Lower the root logger's threshold to DEBUG so that progress messages
# emitted through the logging module during training are not filtered out.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Variables are placeholders for input arrays. We give each variable a unique name.
data = mx.symbol.Variable('data')

# The input is fed to a fully connected layer that computes Y = WX + b.
# This is the main computation module in the network.
# Each layer also needs a unique name.
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)

# Activation layers apply a non-linear function on the previous layer's output.
# Here we use Rectified Linear Unit (ReLU) that computes Y = max(X, 0).
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")

# Second hidden layer: 64 units followed by ReLU.
fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")

# Output layer: 10 units, one per class.
fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)

# Finally we have a loss layer that compares the network's output with the
# label and generates gradient signals.
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

# mx.viz.plot_network(mlp)

# The return value is discarded here, so outside an interactive session
# (where the result would be echoed) this call shows nothing.
mlp.list_arguments()
from sklearn.datasets import fetch_mldata

# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed
# in 0.22 (fetch_openml is its successor) -- confirm the installed version
# still provides it before running this script.
mnist = fetch_mldata('MNIST original')

# Shuffle samples and labels with the same permutation.
np.random.seed(1234)  # set seed for deterministic ordering
p = np.random.permutation(mnist.data.shape[0])
X = mnist.data[p]
Y = mnist.target[p]

# Optional preview of the first ten digits (requires matplotlib.pyplot as plt):
# for i in range(10):
#     plt.subplot(1, 10, i + 1)
#     plt.imshow(X[i].reshape((28, 28)), cmap='Greys_r')
#     plt.axis('off')
# plt.show()

# Convert to float32 and divide by 255; anything fed to the trained model at
# prediction time needs the same scaling.
X = X.astype(np.float32) / 255

# First 60000 shuffled samples are used for training, the rest for testing.
X_train = X[:60000]
X_test = X[60000:]
Y_train = Y[:60000]
Y_test = Y[60000:]

batch_size = 100
train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=batch_size)
test_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=batch_size)
model = mx.model.FeedForward(
    ctx=mx.cpu(),          # Run on the CPU
    symbol=mlp,            # Use the network we just defined
    num_epoch=10,          # Train for 10 epochs
    learning_rate=0.1,     # Learning rate
    momentum=0.9,          # Momentum for SGD with momentum
    wd=0.00001)            # Weight decay for regularization

model.fit(
    X=train_iter,          # Training data set
    eval_data=test_iter,   # Testing data set. MXNet computes scores on test set every epoch
    batch_end_callback=mx.callback.Speedometer(batch_size, 200))  # Logging module to print out progress

# Predicted class index for the first test sample.
print(model.predict(X_test[0:1])[0].argmax())

# Single-argument print() with %-formatting emits the same text under both
# Python 2 and Python 3.
print('Accuracy: %s %%' % (model.score(test_iter) * 100,))
from yhat import Yhat, YhatModel


class MxNetModel(YhatModel):
    """Yhat wrapper that classifies one image with the trained network."""

    def execute(self, data):
        """Return the predicted class for the image in ``data['image']``.

        ``data['image']`` is a flat sequence of raw pixel intensities on the
        0-255 scale (the form used by ``testcase`` in this script).

        Returns a dict ``{"guess": <int>}`` holding the index of the
        highest-scoring output of the module-level ``model``.
        """
        # The network was trained on float32 pixels divided by 255, so apply
        # the same scaling here; feeding raw 0-255 values would present the
        # model with inputs on a different scale than it was fitted on.
        batch = np.array([data['image']], dtype=np.float32) / 255
        # argmax() yields a NumPy scalar; hand back a plain int instead.
        return {"guess": int(model.predict(batch)[0].argmax())}
# Sample request payload: 784 raw pixel intensities (0-255) as one flat list.
# The values are laid out 28 per line purely for readability -- presumably a
# 28x28 digit flattened row by row (TODO confirm against the data source).
testcase = {'image': [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 159, 253, 159, 50, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 238, 252, 252, 252, 237, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54, 227, 253, 252, 239, 233, 252, 57, 6, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 60, 224, 252, 253, 252, 202, 84, 252, 253, 122, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 163, 252, 252, 252, 253, 252, 252, 96, 189, 253, 167, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 238, 253, 253, 190, 114, 253, 228, 47, 79, 255, 168, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 238, 252, 252, 179, 12, 75, 121, 21, 0, 0, 253, 243, 50, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 38, 165, 253, 233, 208, 84, 0, 0, 0, 0, 0, 0, 253, 252, 165, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 7, 178, 252, 240, 71, 19, 28, 0, 0, 0, 0, 0, 0, 253, 252, 195, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 57, 252, 252, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 252, 195, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 198, 253, 190, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 253, 196, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 76, 246, 252, 112, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 252, 148, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 85, 252, 230, 25, 0, 0, 0, 0, 0, 0, 0, 0, 7, 135, 253, 186, 12, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 85, 252, 223, 0, 0, 0, 0, 0, 0, 0, 0, 7, 131, 252, 225, 71, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 85, 252, 145, 0, 0, 0, 0, 0, 0, 0, 48, 165, 252, 173, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 86, 253, 225, 0, 0, 0, 0, 0, 0, 114, 238, 253, 162, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 85, 252, 249, 146, 48, 29, 85, 178, 225, 253, 223, 167, 56, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 85, 252, 252, 252, 229, 215, 252, 252, 252, 196, 130, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 28, 199, 252, 252, 253, 252, 252, 233, 145, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 25, 128, 252, 253, 252, 141, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
# Exercise the model wrapper locally on the sample payload before deploying.
print(MxNetModel().execute(testcase))

# NOTE: USERNAME, APIKEY and URL are not defined anywhere in this script;
# assign them before running, otherwise the next line raises NameError.
yh = Yhat(USERNAME, APIKEY, URL)
print(yh.deploy("MxNetExample", MxNetModel, globals(), sure=True))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment