Last active
May 3, 2019 21:04
-
-
Save roywei/330e29b392d017f8bb01db17b96881a5 to your computer and use it in GitHub Desktop.
Gluon Estimator Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Gluon Estimator example on MNIST dataset with simple CNN""" | |
import sys | |
import os | |
import time | |
import mxnet as mx | |
from mxnet import gluon, nd, autograd | |
from mxnet import metric | |
from mxnet.gluon import nn, data | |
from mxnet.gluon.contrib.estimator import Estimator | |
def load_data_mnist(batch_size, resize=None, root=os.path.join(
        '~', '.mxnet', 'datasets', 'mnist')):
    """Build train/test DataLoaders for the MNIST dataset.

    Parameters
    ----------
    batch_size : int
        Number of samples per mini-batch.
    resize : int, optional
        If given, images are resized to (resize, resize) before ToTensor.
    root : str
        Directory where MNIST is stored (downloaded on first use).

    Returns
    -------
    (DataLoader, DataLoader)
        Training loader (shuffled) and test loader (not shuffled).
    """
    root = os.path.expanduser(root)  # resolve '~' to the home directory
    steps = [data.vision.transforms.Resize(resize)] if resize else []
    steps.append(data.vision.transforms.ToTensor())
    transform = data.vision.transforms.Compose(steps)
    train_set = data.vision.MNIST(root=root, train=True)
    test_set = data.vision.MNIST(root=root, train=False)
    # Multiprocess data loading is disabled on Windows (no fork()).
    workers = 0 if sys.platform.startswith('win32') else 4
    train_loader = data.DataLoader(
        train_set.transform_first(transform), batch_size, shuffle=True,
        num_workers=workers)
    test_loader = data.DataLoader(
        test_set.transform_first(transform), batch_size, shuffle=False,
        num_workers=workers)
    return train_loader, test_loader
# --- Model: simple CNN for MNIST ------------------------------------------
# Two conv layers -> max-pool -> dropout -> flatten -> dense head with 10
# output logits (one per digit class).
net = nn.HybridSequential()
net.add(nn.Conv2D(32, kernel_size=3, activation='relu'),
        nn.Conv2D(64, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2),
        nn.Dropout(0.25),
        nn.Flatten(),
        nn.Dense(128, activation="relu"),
        # Fix: the original stacked two identical Dropout(0.5) layers back
        # to back, which effectively drops ~75% of activations. One is the
        # intended regularization.
        nn.Dropout(0.5),
        nn.Dense(10))

# Use all available GPUs, otherwise fall back to CPU. Query the device
# list once (the original called list_gpus() twice).
gpus = mx.test_utils.list_gpus()
ctx = [mx.gpu(i) for i in gpus] if gpus else [mx.cpu()]

batch_size = 128
train_data, test_data = load_data_mnist(batch_size, resize=28)

net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
net.hybridize()  # compile to a static graph for faster execution
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
train_acc = metric.Accuracy()

############## Using Gluon Estimator ###################
# Estimator owns the training loop: forward/backward passes, metric
# updates, and validation are all driven by fit().
est = Estimator(net=net, loss=loss_fn, metrics=train_acc, trainer=trainer, context=ctx)
est.fit(train_data=train_data, val_data=test_data, epochs=10)
############# Using imperative Gluon for loop ########## | |
def evaluate_accuracy(data_loader, net, ctx):
    """Return the classification accuracy of ``net`` over ``data_loader``.

    Each batch is split across the devices in ``ctx``; correct argmax
    predictions are accumulated and divided by the total sample count.
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]  # normalize a single device to a list
    correct = nd.array([0])
    total = 0
    for batch in data_loader:
        xs = gluon.utils.split_and_load(batch[0], ctx_list=ctx)
        ys = gluon.utils.split_and_load(batch[1], ctx_list=ctx)
        for x, y in zip(xs, ys):
            y = y.astype('float32')
            preds = net(x).argmax(axis=1)
            # Copy per-device counts to CPU so they can be summed together.
            correct += (preds == y).sum().copyto(mx.cpu())
            total += y.size
    correct.wait_to_read()  # block until the async accumulation finishes
    return correct.asscalar() / total
"""Train and evaluate a model.""" | |
print('training on', ctx) | |
for epoch in range(10): | |
train_l_sum, train_acc_sum, n, m, start = 0.0, 0.0, 0, 0, time.time() | |
for i, batch in enumerate(train_data): | |
data = gluon.utils.split_and_load(batch[0], ctx_list=ctx) | |
label = gluon.utils.split_and_load(batch[1], ctx_list=ctx) | |
with autograd.record(): | |
y_hats = [net(X) for X in data] | |
ls = [loss_fn(y_hat, y) for y_hat, y in zip(y_hats, label)] | |
for l in ls: | |
l.backward() | |
trainer.step(batch_size) | |
train_l_sum += sum([l.sum().asscalar() for l in ls]) | |
n += sum([l.size for l in ls]) | |
train_acc_sum += sum([(y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar() | |
for y_hat, y in zip(y_hats, label)]) | |
m += sum([y.size for y in label]) | |
test_acc = evaluate_accuracy(test_data, net, ctx) | |
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, ' | |
'time %.1f sec' | |
% (epoch + 1, train_l_sum / n, train_acc_sum / m, test_acc, | |
time.time() - start)) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.