@mjamroz
Created January 26, 2020 19:20
import mxnet as mx
import argparse, os
from matplotlib import pyplot as plt
from gluoncv.model_zoo import get_model
#mx.random.seed(42)
# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--classes', type=int, default=1000,
                        help='number of classes')
    parser.add_argument('--rec-train', type=str, default='images_train.rec',
                        help='the training data')
    parser.add_argument('--rec-train-idx', type=str, default='images_train.idx',
                        help='the index of training data')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--model', type=str, required=True,
                        help='type of model to use. see vision_model for options.')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    opt = parser.parse_args()
    return opt
def prep_net(classes, context, opt):
    # Load a pretrained model and swap its output layer for a fresh Dense
    # layer sized for the target number of classes (transfer learning).
    net = get_model(opt.model, pretrained=True)
    with net.name_scope():
        net.output = mx.gluon.nn.Dense(classes)
    net.output.initialize(mx.init.Xavier(), ctx=context)
    net.collect_params().reset_ctx(context)
    if opt.no_wd:
        # Optionally remove weight decay from bias and batchnorm beta/gamma parameters
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0
    net.hybridize(static_alloc=True, static_shape=True)
    return net
class Learner():
    def __init__(self, net, data_loader, ctx, opt):
        """
        :param net: network (mx.gluon.Block)
        :param data_loader: training data iterator (mx.io.DataIter, e.g. ImageRecordIter)
        :param ctx: list of contexts (mx.gpu or mx.cpu)
        :param opt: parsed command line options (argparse.Namespace)
        """
        self.net = net
        self.opt = opt
        self.data_loader = data_loader
        self.ctx = ctx
        #self.net.initialize(mx.init.Xavier(), ctx=self.ctx)
        self.net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
        self.loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
        optimizer_params = {'learning_rate': .001, 'wd': self.opt.wd, 'momentum': self.opt.momentum}
        if self.opt.dtype != 'float32':
            optimizer_params['multi_precision'] = True
        self.trainer = mx.gluon.Trainer(net.collect_params(), 'nag', optimizer_params)

    def iteration(self, lr=None, take_step=True):
        """
        :param lr: learning rate to use for iteration (float)
        :param take_step: take trainer step to update weights (boolean)
        :return: iteration loss (float)
        """
        # Update learning rate if different this iteration
        if lr and (lr != self.trainer.learning_rate):
            self.trainer.set_learning_rate(lr)
        # Get next batch, resetting the iterator when an epoch ends
        try:
            bt = next(self.data_loader)
        except StopIteration:
            self.data_loader.reset()
            bt = next(self.data_loader)
        # Move data to the target context(s), e.g. the GPU(s) if set
        data = mx.gluon.utils.split_and_load(bt.data[0], ctx_list=self.ctx, batch_axis=0)
        label = mx.gluon.utils.split_and_load(bt.label[0], ctx_list=self.ctx, batch_axis=0)
        # Standard forward and backward pass
        with mx.autograd.record():
            outputs = [self.net(X.astype(self.opt.dtype, copy=False)) for X in data]
            loss = [self.loss_fn(yhat, y.astype(self.opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        # Update parameters
        if take_step:
            self.trainer.step(data[0].shape[0])
        # Set and return loss.
        ls = 0.0
        for l in loss:
            ls += mx.nd.mean(l).asscalar()
        self.iteration_loss = ls
        return self.iteration_loss

    def close(self):
        # Shut down the iterator's worker processes, if the loader has any
        if hasattr(self.data_loader, 'shutdown'):
            self.data_loader.shutdown()
class LRFinder():
    def __init__(self, learner):
        """
        :param learner: able to take single iteration with given learning rate and return loss,
                        and save and load parameters of the network (Learner)
        """
        self.learner = learner

    def find(self, lr_start=1e-6, lr_multiplier=1.1, smoothing=0.3):
        """
        :param lr_start: learning rate to start search (float)
        :param lr_multiplier: factor the learning rate is multiplied by at each step of search (float)
        :param smoothing: amount of smoothing applied to loss for stopping criteria (float)
        :return: learning rate and loss pairs (list of (float, float) tuples)
        """
        # Pass data through without taking a step; this triggers lazy weight
        # initialization for a new model.
        self.learner.iteration(take_step=False)
        # Used to initialize trainer (if no step has been taken)
        if not self.learner.trainer._kv_initialized:
            self.learner.trainer._init_kvstore()
        # Store params and optimizer state for restore after lr_finder procedure.
        # Useful for applying the method partway through training, not just for initialization of lr.
        self.learner.net.save_parameters("lr_finder.params")
        self.learner.trainer.save_states("lr_finder.state")
        lr = lr_start
        self.results = []  # List of (lr, loss) tuples
        stopping_criteria = LRFinderStoppingCriteria(smoothing)
        while True:
            # Run iteration, and block until loss is calculated.
            loss = self.learner.iteration(lr)
            self.results.append((lr, loss))
            if stopping_criteria(loss):
                break
            lr = lr * lr_multiplier
        # Restore params (as finder changed them)
        self.learner.net.load_parameters("lr_finder.params", ctx=self.learner.ctx)
        self.learner.trainer.load_states("lr_finder.state")
        return self.results

    def plot(self):
        lrs = [e[0] for e in self.results]
        losses = [e[1] for e in self.results]
        plt.figure(figsize=(6, 8))
        plt.scatter(lrs, losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.yscale('log')
        axes = plt.gca()
        axes.set_xlim([lrs[0], lrs[-1]])
        y_lower = min(losses) * 0.8
        y_upper = losses[0] * 4
        axes.set_ylim([y_lower, y_upper])
        plt.savefig("find_lr.png")
        plt.show()
class LRFinderStoppingCriteria():
    def __init__(self, smoothing=0.3, min_iter=20):
        """
        :param smoothing: applied to running mean which is used for thresholding (float)
        :param min_iter: minimum number of iterations before early stopping can occur (int)
        """
        self.smoothing = smoothing
        self.min_iter = min_iter
        self.first_loss = None
        self.running_mean = None
        self.counter = 0

    def __call__(self, loss):
        """
        :param loss: from single iteration (float)
        :return: indicator to stop (boolean)
        """
        self.counter += 1
        if self.first_loss is None:
            self.first_loss = loss
        if self.running_mean is None:
            self.running_mean = loss
        else:
            self.running_mean = ((1 - self.smoothing) * loss) + (self.smoothing * self.running_mean)
        return (self.running_mean > self.first_loss * 2) and (self.counter >= self.min_iter)
def get_data_rec(rec_train, rec_train_idx, batch_size):
    rec_train = os.path.expanduser(rec_train)
    rec_train_idx = os.path.expanduser(rec_train_idx)
    input_size = 224
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    train_data = mx.io.ImageRecordIter(
        path_imgrec = rec_train,
        path_imgidx = rec_train_idx,
        preprocess_threads = 4,
        shuffle = True,
        batch_size = batch_size,
        data_shape = (3, input_size, input_size),
        mean_r = mean_rgb[0],
        mean_g = mean_rgb[1],
        mean_b = mean_rgb[2],
        std_r = std_rgb[0],
        std_g = std_rgb[1],
        std_b = std_rgb[2],
        rand_crop = True
    )
    return train_data
opt = parse_args()
context = [mx.gpu(i) for i in range(opt.num_gpus)] if opt.num_gpus > 0 else [mx.cpu()]
net = prep_net(opt.classes, context, opt)
train_data = get_data_rec(opt.rec_train, opt.rec_train_idx, opt.batch_size)
learner = Learner(net=net, data_loader=train_data, ctx=context, opt=opt)
lr_finder = LRFinder(learner)
lr_finder.find(lr_start=1e-6)
lr_finder.plot()
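# Example usage (a sketch, not part of the original gist: the script name
# "find_lr.py" and the model/dataset values below are illustrative assumptions):
#
#   python find_lr.py --model resnet50_v2 --classes 120 \
#       --rec-train images_train.rec --rec-train-idx images_train.idx \
#       --batch-size 64 --num-gpus 1
#
# This writes the (learning rate, loss) curve to find_lr.png. A common heuristic
# is to choose a learning rate roughly an order of magnitude below the point
# where the smoothed loss stops decreasing and starts to blow up.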