@msakai
Created February 1, 2020 00:54
import argparse
import numpy
import torch
from torch import nn
import torch.nn.functional as F
import ignite
import chainer
from chainer import training
from chainer.training import extensions
import chainer_pytorch_migration as cpm
import chainer_pytorch_migration.ignite
import matplotlib
matplotlib.use('Agg')
# Network definition
class MLP(nn.Module):

    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(n_in, n_units)       # n_in -> n_units
        self.l2 = nn.Linear(n_units, n_units)    # n_units -> n_units
        self.l3 = nn.Linear(n_units, n_out)      # n_units -> n_out

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
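
# Shape sanity check (illustrative only, not part of the training script):
#   >>> net = MLP(784, 1000, 10)
#   >>> net(torch.zeros(2, 784)).shape
#   torch.Size([2, 10])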
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='cpu',
                        help='Device specifier. e.g. \'cpu\' or \'cuda:0\'')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload', action='store_true',
                        help='Automatically load trainer snapshots in case'
                             ' of preemption or other temporary system failure')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    device = torch.device(args.device)

    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')
    # Set up a neural network to train.
    # The softmax cross entropy loss and the validation metrics are reported
    # by the Ignite trainer/evaluator handlers defined below, and are used by
    # the PrintReport extension further down.
    model = MLP(784, args.unit, 10)
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    def collate_fn(minibatch):
        xs = []
        ys = []
        for x, y in minibatch:
            xs.append(x)
            ys.append(y)
        return torch.FloatTensor(xs), torch.LongTensor(ys)
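    # Each element of the Chainer MNIST dataset is a pair of a flattened
    # 784-dimensional float32 image (scaled to [0, 1]) and an integer label;
    # the collate_fn above stacks a minibatch of such pairs into a FloatTensor
    # of shape (batchsize, 784) and a LongTensor of labels for PyTorch.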
    train_loader = torch.utils.data.DataLoader(
        train, shuffle=True, batch_size=args.batchsize, pin_memory=True,
        collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(
        test, shuffle=False, batch_size=args.batchsize, pin_memory=True,
        collate_fn=collate_fn)

    # Set up a trainer
    trainer = ignite.engine.create_supervised_trainer(
        model, optimizer, F.cross_entropy, device=device)

    # Evaluate the model with the test dataset for each epoch
    evaluator = ignite.engine.create_supervised_evaluator(
        model,
        metrics={
            'accuracy': ignite.metrics.Accuracy(),
            'loss': ignite.metrics.Loss(F.cross_entropy),
        },
        device=device)

    @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
    def validation(engine):
        evaluator.run(test_loader)
        # print('validation_loss', evaluator.state.metrics['loss'])
        # print('validation_accuracy', evaluator.state.metrics['accuracy'])
        chainer.reporter.report({
            'validation_loss': evaluator.state.metrics['loss'],
            'validation_accuracy': evaluator.state.metrics['accuracy']
        })

    @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
    def report_loss(engine):
        # ``engine.state.output`` is the per-iteration loss value returned by
        # the engine created with ``create_supervised_trainer``.
        chainer.reporter.report({'loss': engine.state.output})
        # if evaluator.state:
        #     chainer.reporter.report({
        #         'validation_loss': evaluator.state.metrics['loss'],
        #         'validation_accuracy': evaluator.state.metrics['accuracy'],
        #     })
    # chainer-pytorch-migration's trainer extensions need to know the target
    # model and the output directory.
    optimizer.target = model
    trainer.out = args.out

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)

    # Take a snapshot every ``frequency`` epochs, delete old stale snapshots,
    # and automatically load from snapshot files if any files are already
    # resident in the result directory.
    cpm.ignite.add_trainer_extension(
        trainer, optimizer,
        extensions.snapshot(n_retains=1, autoload=args.autoload),
        trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.LogReport())

    # Save two plot images to the result dir
    # cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PlotReport(
    #     ['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
    cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PlotReport(
        ['loss', 'validation_loss'], 'epoch', file_name='loss.png'))
    # cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PlotReport(
    #     ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))
    cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PlotReport(
        ['validation_accuracy'], 'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout.
    # The 'loss', 'validation_loss' and 'validation_accuracy' entries are the
    # values reported by the handlers defined above; they replace the
    # 'main/...' and 'validation/main/...' entries reported by the Classifier
    # link in the original Chainer example.
    # cpm.ignite.add_trainer_extension(extensions.PrintReport(
    #     ['epoch', 'main/loss', 'validation/main/loss',
    #      'main/accuracy', 'validation/main/accuracy', 'elapsed_time']),
    #     call_before_training=True)
    cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PrintReport(
        ['epoch', 'elapsed_time', 'loss', 'validation_loss', 'validation_accuracy']))

    # Print a progress bar to stdout
    cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot (Note: the model loaded here will be
        # overwritten by the --autoload option, which autoloads snapshots, if
        # any snapshots exist in the output directory)
        cpm.ignite.load_chainer_snapshot(trainer, optimizer, args.resume)

    # Run the training
    trainer.run(train_loader, max_epochs=args.epoch)
if __name__ == '__main__':
    main()
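
# Example invocation (the script name below is only a placeholder for wherever
# this gist is saved; chainer, torch, pytorch-ignite, matplotlib and
# chainer-pytorch-migration are assumed to be installed):
#
#     python train_mnist_cpm.py --device cuda:0 --epoch 20 --out result
#
# PrintReport and ProgressBar write to stdout, while LogReport and PlotReport
# write log, loss.png and accuracy.png under the directory given by --out.
# Snapshots taken by extensions.snapshot() can be restored with --resume or
# picked up automatically with --autoload.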