Skip to content

Instantly share code, notes, and snippets.

@mlzxy
Forked from zhreshold/train_imagenet.py
Created April 28, 2018 04:19
Show Gist options
  • Save mlzxy/05806018e100a718f687ffd0b157cda1 to your computer and use it in GitHub Desktop.
Train imagenet using gluon
import argparse, time
import logging
# Root logger prints to the console (basicConfig) and, via the extra file
# handler, mirrors every record into training.log.
logging.basicConfig(level=logging.INFO)
fh = logging.FileHandler('training.log')
logger = logging.getLogger()
logger.addHandler(fh)
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision
from mxnet import autograd as ag
class DummyIter(mx.io.DataIter):
    """Synthetic DataIter that serves the same all-zero batch repeatedly.

    Used in --benchmark mode so the measured throughput excludes any real
    data loading / decoding cost.
    """

    def __init__(self, batch_size, data_shape, batches=5):
        super(DummyIter, self).__init__(batch_size)
        self.data_shape = (batch_size,) + data_shape
        self.label_shape = (batch_size,)
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('softmax_label', self.label_shape)]
        # A single pre-allocated batch is reused by every call to next().
        self.batch = mx.io.DataBatch(
            data=[mx.nd.zeros(self.data_shape)],
            label=[mx.nd.zeros(self.label_shape)])
        self._batches = 0
        self.batches = batches

    def next(self):
        # Guard clause: once the per-"epoch" budget is spent, rewind the
        # counter so the iterator can be reused, then signal end of epoch.
        if self._batches >= self.batches:
            self._batches = 0
            raise StopIteration
        self._batches += 1
        return self.batch
def dummy_iterator(batch_size, data_shape):
    """Return a (train, val) pair of zero-filled DummyIter benchmark iterators."""
    train = DummyIter(batch_size, data_shape)
    val = DummyIter(batch_size, data_shape)
    return train, val
# CLI
parser = argparse.ArgumentParser(description='Train a model for imagenet classification')
# --- data locations (RecordIO files + training index) ---
parser.add_argument('--train-rec', type=str, required=True,
help='training record file')
parser.add_argument('--val-rec', type=str, required=True,
help='validation record file')
parser.add_argument('--train-idx', type=str, required=True,
help='train index file')
# --- hardware / schedule ---
parser.add_argument('--gpus', type=int, default=0,
help='number of gpus to use')
parser.add_argument('--epochs', type=int, default=120,
help='number of total epochs')
parser.add_argument('--batch-size', type=int, default=256,
help='batch size')
# --- SGD hyper-parameters ---
parser.add_argument('--lr', type=float, default=0.1,
help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
help='momentum')
parser.add_argument('--wd', type=float, default=1e-4,
help='weight decay')
# --- checkpointing / resumption ---
parser.add_argument('--start-epoch', type=int, default=0,
help='starting epoch')
# NOTE(review): despite the help text, --resume is consumed as an integer
# epoch number inside train() (int(args.resume)), not as a file path — confirm
# the intended usage.
parser.add_argument('--resume', type=str, default='',
help='path to checkpoint')
parser.add_argument('--seed', type=int, default=123,
help='random seed to use. Default=123.')
parser.add_argument('--benchmark', action='store_true',
help='whether to run benchmark.')
parser.add_argument('--mode', type=str,
help='mode in which to train the model. options are symbolic, imperative, hybrid')
# NOTE(review): this help string looks truncated; in the pipeline setup below,
# 'cc' selects mx.io.ImageRecordIter and any other value selects mx.image.ImageIter.
parser.add_argument('--iter', type=str,
help='type of iterator to use, cc to use .')
parser.add_argument('--model', type=str, required=True,
help='type of model to use. see vision_model for options.')
parser.add_argument('--use_thumbnail', action='store_true',
help='use thumbnail or not in resnet. default is false.')
parser.add_argument('--batch-norm', action='store_true',
help='enable batch normalization or not in vgg. default is false.')
parser.add_argument('--pretrained', action='store_true',
help='enable using pretrained model from gluon.')
parser.add_argument('--log-interval', type=int, default=50,
help='Number of batches to wait before logging.')
args = parser.parse_args()
logging.info(str(args))
mx.random.seed(args.seed)  # seed MXNet's RNG for reproducibility
# Device list: one context per requested GPU, otherwise a single CPU context.
ctx = [mx.gpu(i) for i in range(args.gpus)] if args.gpus > 0 else [mx.cpu()]

# Model-zoo construction arguments; resnet and vgg take extra family-specific
# flags, everything else uses the common set.
model_kwargs = {'ctx': ctx, 'pretrained': args.pretrained, 'classes': 1000}
model_name = args.model
if model_name.startswith('resnet'):
    model_kwargs['thumbnail'] = args.use_thumbnail
elif model_name.startswith('vgg'):
    model_kwargs['batch_norm'] = args.batch_norm
net = vision.get_model(model_name, **model_kwargs)
data_shape = (3, 224, 224)
# Input pipelines. Benchmark mode substitutes constant zero batches so I/O is
# excluded from the measurement; otherwise '--iter cc' picks the C++-backed
# mx.io.ImageRecordIter and any other value picks the Python mx.image.ImageIter.
if args.benchmark:
    train_iter, val_iter = dummy_iterator(args.batch_size, data_shape)
elif args.iter == 'cc':
    # Per-channel ImageNet mean/std normalisation; random crop + mirror on train.
    train_iter = mx.io.ImageRecordIter(
        path_imgrec=args.train_rec, data_shape=data_shape,
        shuffle=True, mean_r=123.68, mean_g=116.28, mean_b=103.53,
        std_r=58.395, std_g=57.12, std_b=57.375,
        batch_size=args.batch_size, rand_crop=True,
        # NOTE(review): min_crop_size=38 looks unusually small next to
        # max_crop_size=480 (possibly a typo for 384) — verify.
        max_crop_size=480, min_crop_size=38, rand_mirror=True)
    val_iter = mx.io.ImageRecordIter(
        path_imgrec=args.val_rec, data_shape=data_shape,
        shuffle=False, mean_r=123.68, mean_g=116.28, mean_b=103.53,
        std_r=58.395, std_g=57.12, std_b=57.375,
        batch_size=args.batch_size, rand_crop=False, rand_mirror=False)
else:
    train_iter = mx.image.ImageIter(
        args.batch_size, data_shape, path_imgrec=args.train_rec,
        path_imgidx=args.train_idx, shuffle=True, mean=True,
        std=True, rand_resize=True, rand_crop=True, rand_mirror=True)
    val_iter = mx.image.ImageIter(
        args.batch_size, data_shape, path_imgrec=args.val_rec,
        shuffle=False, mean=True, std=True, resize=256)
def validate():
    """Run one full pass of `net` over `val_iter`.

    Returns the (names, values) pair from a composite top-1/top-5 accuracy
    metric. Uses the module-level `net`, `val_iter` and `ctx`.
    """
    metric = mx.metric.CompositeEvalMetric(['acc', mx.metric.TopKAccuracy(5)])
    val_iter.reset()
    for batch in val_iter:
        # Shard the batch across all contexts along the batch axis.
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        preds = [net(shard) for shard in data]
        metric.update(label, preds)
    return metric.get()
def train(epochs, start_epoch):
    """Imperative/hybrid training loop for the module-level `net`.

    Parameters:
        epochs      -- total number of epochs to train to.
        start_epoch -- first epoch index (overridden when --resume is given).

    Side effects: loads/initializes parameters, logs training and validation
    top-1/top-5 accuracy, and saves 'imagenet-<model>-<epoch+1>.params' after
    every epoch.
    """
    if args.resume:
        # NOTE(review): --resume help says "path to checkpoint", but the value
        # is parsed here as the epoch number of a previously saved params file.
        start_epoch = int(args.resume)
        net.load_params('imagenet-%s-%d.params' % (args.model, start_epoch), ctx=ctx)
        logging.info('loaded from epoch %d', start_epoch)
    elif not args.pretrained:
        net.initialize(mx.init.Xavier(factor_type='out', magnitude=2), ctx=ctx)
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
    trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
    metric = mx.metric.CompositeEvalMetric(['acc', mx.metric.TopKAccuracy(5)])
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for epoch in range(start_epoch, epochs):
        # Step schedule: lr = base_lr / 10^(epoch // 30). The original derived
        # the phase from start_epoch on every loop iteration, which (a) froze
        # the schedule after a resume and (b) rebuilt the Trainer -- discarding
        # momentum state -- every epoch. Rebuild only when the lr changes.
        lr = args.lr / (10 ** (epoch // 30))
        if lr != optimizer_params['learning_rate']:
            optimizer_params['learning_rate'] = lr
            logging.info('Reduce learning rate to %f', lr)
            trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
        tic = time.time()
        train_iter.reset()
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_iter):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    # Use the gluon loss defined above (previously unused; the
                    # original called mx.nd.SoftmaxOutput, whose custom backward
                    # produces the same softmax cross-entropy gradient).
                    L = loss(z, y)
                    # Store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
            for L in Ls:
                L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if args.log_interval and not (i + 1) % args.log_interval:
                name, acc = metric.get()
                logging.info('[Epoch %d Batch %d] speed: %f samples/s, training: %s=%f, %s=%f' % (
                    epoch, i, args.batch_size / (time.time() - btic), name[0], acc[0], name[1], acc[1]))
                btic = time.time()
        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f, %s=%f' % (epoch, name[0], acc[0], name[1], acc[1]))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        name, val_acc = validate()
        logging.info('[Epoch %d] validation: %s=%f, %s=%f' % (epoch, name[0], val_acc[0], name[1], val_acc[1]))
        net.save_params('imagenet-%s-%d.params' % (args.model, epoch + 1))
if __name__ == '__main__':
    if args.mode == 'symbolic':
        # Symbolic path: trace the gluon network into a symbol, attach a
        # softmax head, and train through the classic Module.fit API.
        data = mx.sym.var('data')
        out = net(data)
        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
        mod = mx.mod.Module(softmax, context=ctx)
        mod.fit(train_iter, val_iter,
                num_epoch=args.epochs,
                optimizer='sgd',
                optimizer_params={'learning_rate': args.lr, 'wd': args.wd,
                                  'momentum': args.momentum},
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 1),
                epoch_end_callback=mx.callback.do_checkpoint(args.model),
                initializer=mx.init.Xavier(factor_type='out', magnitude=2),
                eval_metric=['acc', mx.metric.TopKAccuracy(5)],
                validation_metric=['acc', mx.metric.TopKAccuracy(5)])
    else:
        # Imperative path; 'hybrid' additionally compiles the network graph.
        if args.mode == 'hybrid':
            net.hybridize()
        train(args.epochs, args.start_epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment