tune_cifar.py, by @mlzxy, forked from zhreshold/tune_cifar.py, created April 28, 2018.
CIFAR10 with a Gluon model: trains ResNet-101 v2 (thumbnail variant) with SGD, step learning-rate decay at epochs 150 and 250, and optional multi-GPU data parallelism.
import argparse
import logging
import random
import time
import mxnet as mx
from mxnet import nd
from mxnet import image
from mxnet import gluon
from mxnet import autograd
import numpy as np

def parse_args():
    parser = argparse.ArgumentParser(description="Train CIFAR10.")
    parser.add_argument('--batch-size', type=int, default=128,
                        help='training batch size.')
    parser.add_argument('--num-gpus', type=int, default=1,
                        help='number of gpus to use.')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer. default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--seed', type=int, default=123,
                        help='random seed to use. default is 123.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='number of batches to wait before logging.')
    parser.add_argument('--kvstore', type=str, default='device',
                        help='kvstore to use for trainer/module.')
    args = parser.parse_args()
    return args

def get_data_rec(batch_size):
    import os
    data_dir = "data"

    def download_cifar10():
        from mxnet.test_utils import download
        fnames = (os.path.join(data_dir, "cifar10_train.rec"),
                  os.path.join(data_dir, "cifar10_val.rec"))
        download('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1])
        download('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0])
        return fnames

    (train_fname, val_fname) = download_cifar10()
    train = mx.io.ImageRecordIter(
        path_imgrec      = train_fname,
        label_width      = 1,
        data_name        = 'data',
        label_name       = 'softmax_label',
        data_shape       = (3, 32, 32),
        batch_size       = batch_size,
        pad              = 4,
        fill_value       = 127,   # only used when pad is set
        rand_crop        = True,
        max_random_scale = 1.0,   # 480 with imagenet, 32 with cifar10
        min_random_scale = 1.0,   # 256.0/480.0
        max_aspect_ratio = 0,
        random_h         = 0,
        random_s         = 0,
        random_l         = 0,
        max_rotate_angle = 0,
        max_shear_ratio  = 0,
        rand_mirror      = True,
        shuffle          = True)
    val = mx.io.ImageRecordIter(
        path_imgrec = val_fname,
        label_width = 1,
        data_name   = 'data',
        label_name  = 'softmax_label',
        batch_size  = batch_size,
        data_shape  = (3, 32, 32),
        rand_crop   = False,
        rand_mirror = False)
    return train, val

def get_data(batch_size):
    # transform = lambda data, label: (data.astype('float32').transpose((2, 0, 1))/255, label)
    def train_transform(data, label):
        # resize to 40x40, then random-crop back to 32x32 and randomly flip
        data = mx.image.imresize(data, 40, 40)
        data = mx.image.RandomCropAug((32, 32))(data)
        data = mx.image.HorizontalFlipAug(0.5)(data)
        data = data.astype('float32') / 255
        # normalize with per-channel CIFAR10 mean and std
        data = (data - mx.nd.array([0.4914, 0.4822, 0.4465])) / mx.nd.array([0.2023, 0.1994, 0.2010])
        data = data.transpose((2, 0, 1))
        return data, label

    def val_transform(data, label):
        data = data.astype('float32') / 255
        data = (data - mx.nd.array([0.4914, 0.4822, 0.4465])) / mx.nd.array([0.2023, 0.1994, 0.2010])
        data = data.transpose((2, 0, 1))
        return data, label

    train_dataset = gluon.data.vision.CIFAR10(transform=train_transform)
    val_dataset = gluon.data.vision.CIFAR10(train=False, transform=val_transform)
    train_data = gluon.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, last_batch='keep')
    val_data = gluon.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, last_batch='keep')
    return train_data, val_data
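
Note that get_data is an unused alternative to get_data_rec: a pure-Gluon pipeline built on gluon.data.DataLoader rather than RecordIO files. Because test and train below handle both batch formats (list batches from a DataLoader, mx.io.DataBatch from an ImageRecordIter), switching the __main__ block to the DataLoader path should only need one changed line:

    # in __main__, instead of get_data_rec(args.batch_size):
    train_data, val_data = get_data(args.batch_size)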

def get_net():
    # thumbnail=True uses the CIFAR-style 3x3 stride-1 stem instead of the
    # 7x7 stride-2 ImageNet stem; classes=10 sizes the output layer for CIFAR10.
    net = gluon.model_zoo.vision.get_model('resnet101_v2', classes=10, thumbnail=True)
    return net

def test(net, val_data, ctx):
    # ImageRecordIter must be reset between passes; DataLoader has no reset().
    try:
        val_data.reset()
    except AttributeError:
        pass
    metric = mx.metric.Accuracy()
    for batch in val_data:
        if isinstance(batch, (list, tuple)):  # batch from DataLoader
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        else:                                 # mx.io.DataBatch from ImageRecordIter
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()

def train(net, train_data, val_data, epochs, lr, momentum, wd, ctx, kvstore, log_interval):
    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
    net.hybridize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd, 'momentum': momentum},
                            kvstore=kvstore)
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    logging.info("Start training on {}.".format(str(ctx)))
    for epoch in range(epochs):
        try:
            train_data.reset()
        except AttributeError:
            pass
        # step decay: divide the learning rate by 10 at epochs 150 and 250
        if epoch in [150, 250]:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
            logging.info('reduced learning rate to {}'.format(trainer.learning_rate))
        tic = time.time()
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_data):
            if isinstance(batch, (list, tuple)):  # batch from DataLoader
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            else:                                 # mx.io.DataBatch from ImageRecordIter
                data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with autograd.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                autograd.backward(Ls)
            # total samples across all device shards in this batch
            batch_size = sum(d.shape[0] for d in data)
            trainer.step(batch_size, ignore_stale_grad=False)
            metric.update(label, outputs)
            if log_interval and not (i + 1) % log_interval:
                name, acc = metric.get()
                logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' % (
                    epoch, i, batch_size / (time.time() - btic), name, acc))
            btic = time.time()
        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        name, val_acc = test(net, val_data, ctx)
        logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    args = parse_args()
    logging.info(args)
    # seed the python, numpy and mxnet RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
    ctx = [mx.gpu(i) for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]
    net = get_net()
    train_data, val_data = get_data_rec(args.batch_size)
    train(net, train_data, val_data, args.epochs, args.lr, args.momentum, args.wd,
          ctx, args.kvstore, args.log_interval)
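
A typical invocation, assuming the script is saved as tune_cifar.py (all flags are optional and fall back to the defaults declared in parse_args):

    python tune_cifar.py --batch-size 128 --num-gpus 1 --epochs 350 --lr 0.1

Pass --num-gpus 0 to train on CPU.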