@corochann
Last active July 15, 2022 15:09
Manual Scheduling of optimizer's learning rate with Chainer trainer
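
The script below trains a small MLP on MNIST with Chainer's Trainer and changes an optimizer hyperparameter at user-specified epochs. The scheduling is done by a small extension, schedule_optimizer_value, which combines ManualScheduleTrigger with chainer.training.extension.make_extension: each time a scheduled epoch is reached, the next value in the list is assigned to the optimizer attribute (lr for MomentumSGD, alpha for Adam).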
from __future__ import print_function
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.optimizers import MomentumSGD, Adam
from chainer.training import extensions, Trainer
from chainer import serializers

opts = {
    'mom': MomentumSGD,
    'adam': Adam,
}


class MLP(chainer.Chain):
    """Simple three-layer perceptron."""

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)  # input size inferred at first forward pass
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def schedule_optimizer_value(epoch_list, value_list, optimizer_name='main',
                             attr_name='lr'):
    """Set an optimizer's hyperparameter according to value_list, scheduled on epoch_list.

    Example usage:
        trainer.extend(schedule_optimizer_value([2, 4, 7], [0.008, 0.006, 0.002]))
    """
    if isinstance(epoch_list, list):
        assert len(epoch_list) == len(value_list)
    else:
        # Allow a single (epoch, value) pair given as scalars.
        assert isinstance(epoch_list, (int, float))
        assert isinstance(value_list, (int, float))
        epoch_list = [epoch_list, ]
        value_list = [value_list, ]

    trigger = chainer.training.triggers.ManualScheduleTrigger(epoch_list, 'epoch')
    count = 0

    @chainer.training.extension.make_extension(trigger=trigger)
    def set_value(trainer: Trainer):
        # Fires once at each epoch in epoch_list and assigns the next
        # scheduled value to the optimizer attribute (e.g. lr or alpha).
        nonlocal count
        optimizer = trainer.updater.get_optimizer(optimizer_name)
        setattr(optimizer, attr_name, value_list[count])
        count += 1

    return set_value
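
# Note: for simple geometric or linear decay, Chainer's built-in
# extensions.ExponentialShift(attr, rate) and extensions.LinearShift(attr, ...)
# already cover the common cases; the manual scheduler above is meant for
# arbitrary step schedules like the one used in main() below.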


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', default='result',
                        help='Directory to output the result')
    parser.add_argument('--opt', '-o', default='mom',
                        help='Optimizer (mom: MomentumSGD, adam: Adam)')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')
    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = MLP(args.unit, 10)
    classifier_model = L.Classifier(model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU

    # Set up an optimizer.
    optimizer = opts[args.opt]()
    optimizer.setup(classifier_model)

    # Load the MNIST dataset.
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(test_iter, classifier_model, device=args.gpu))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    # --- observe_lr: record the current learning rate so it appears in reports ---
    trainer.extend(extensions.observe_lr())

    # --- Manually schedule the learning rate ---
    # Example: set
    #   lr = 0.008 at epoch 2
    #   lr = 0.006 at epoch 4
    #   lr = 0.002 at epoch 7
    trainer.extend(schedule_optimizer_value([2, 4, 7], [0.008, 0.006, 0.002]))
    # trainer.extend(schedule_optimizer_value(3.5, 0.008))
    # trainer.extend(schedule_optimizer_value(3.5, 0.008, attr_name='alpha'))  # when optimizer is Adam

    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'lr', 'elapsed_time']))

    # Plot loss and accuracy for each epoch.
    trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        x_key='epoch', file_name='accuracy.png'))
    # trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume the training from a snapshot.
        serializers.load_npz(args.resume, trainer)

    # Run the training.
    trainer.run()
    serializers.save_npz('{}/mlp.model'.format(args.out), model)


if __name__ == '__main__':
    main()
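
For a quick check of the scheduler without downloading MNIST, here is a minimal sketch (assuming the script above is saved as schedule_lr.py; the module name and the random dummy data are assumptions, not part of the gist). The printed lr column should show the scheduled drops.

import numpy as np
import chainer
import chainer.links as L
from chainer import training
from chainer.optimizers import MomentumSGD
from chainer.training import extensions

from schedule_lr import MLP, schedule_optimizer_value  # assumed module name

# Random dummy data shaped like MNIST, just to drive the updater.
x = np.random.rand(80, 784).astype(np.float32)
t = np.random.randint(0, 10, size=80).astype(np.int32)
train_iter = chainer.iterators.SerialIterator(
    chainer.datasets.TupleDataset(x, t), batch_size=8)

model = L.Classifier(MLP(10, 10))
optimizer = MomentumSGD(lr=0.01)
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (5, 'epoch'))
trainer.extend(schedule_optimizer_value([2, 4], [0.005, 0.001]))
trainer.extend(extensions.observe_lr())
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'lr']))
trainer.run()  # lr: 0.01 initially, 0.005 after epoch 2, 0.001 after epoch 4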