Skip to content

Instantly share code, notes, and snippets.

@msakai
Last active February 7, 2020 06:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msakai/06cf2f4ad937ba3215daee227695499e to your computer and use it in GitHub Desktop.
Save msakai/06cf2f4ad937ba3215daee227695499e to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import ignite
import chainer
from chainer import reporter
from chainer.training import extensions
import chainer_pytorch_migration as cpm
from chainer_pytorch_migration import chainermn
import chainer_pytorch_migration.ignite
class MLP(nn.Module):
    """Three-layer perceptron with ReLU activations between the layers.

    Maps inputs whose last dimension is ``n_in`` through two hidden layers of
    width ``n_units`` to ``n_out`` unnormalized scores (no final activation).

    Args:
        n_in: Number of input features.
        n_units: Width of each of the two hidden layers.
        n_out: Number of output features.
    """

    def __init__(self, n_in: int, n_units: int, n_out: int) -> None:
        super(MLP, self).__init__()
        self.l1 = nn.Linear(n_in, n_units)     # n_in -> n_units
        self.l2 = nn.Linear(n_units, n_units)  # n_units -> n_units
        self.l3 = nn.Linear(n_units, n_out)    # n_units -> n_out

    # Define ``forward`` instead of overriding ``__call__``: the inherited
    # ``nn.Module.__call__`` dispatches to ``forward`` and also runs any
    # registered forward hooks, which a ``__call__`` override would skip.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``n_out`` scores for ``x`` (last dim must be ``n_in``)."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
def _build_parser():
    """Return the argument parser for the ChainerMN MNIST example."""
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='pure_nccl', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    # NOTE(review): accepted on the command line but not read by main().
    parser.add_argument('--chainerx', '-x', action='store_true',
                        default=False, help='Use ChainerX')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    return parser


def main():
    """Train the MLP on MNIST across ChainerMN workers with PyTorch/Ignite.

    Rank 0 loads MNIST and both splits are scattered to all workers. Each
    worker trains on its shard through a ChainerMN multi-node optimizer and
    evaluates on its local test shard after every epoch. Chainer extensions
    (multi-node snapshot, and on rank 0 a log report and progress bar) are
    attached to the Ignite trainer via ``chainer_pytorch_migration``.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.gpu:
        if args.communicator == 'naive':
            # Reports a CLI misuse on stderr and exits with status 2.
            parser.error("'naive' communicator does not support GPU.")
        comm = chainermn.create_communicator(args.communicator)
        # Select the CUDA device by this process's rank within its node.
        device = torch.device('cuda:{}'.format(comm.intra_rank))
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        comm = chainermn.create_communicator('naive')
        device = torch.device('cpu')

    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = MLP(784, args.unit, 10)
    model.to(device)

    # Wrap a standard PyTorch optimizer in ChainerMN's multi-node optimizer.
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset;
    # its datasets are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    def collate_fn(minibatch):
        """Stack ``(x, y)`` examples into a float32 batch and int64 labels."""
        # Stacking per-example tensors avoids the slow element-by-element
        # conversion that ``torch.FloatTensor(list_of_ndarrays)`` performs.
        xs = torch.stack(
            [torch.as_tensor(x, dtype=torch.float32) for x, _ in minibatch])
        ys = torch.as_tensor([int(y) for _, y in minibatch], dtype=torch.long)
        return xs, ys

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory = args.gpu
    train_loader = torch.utils.data.DataLoader(
        train, shuffle=True, batch_size=args.batchsize,
        pin_memory=pin_memory, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(
        test, shuffle=False, batch_size=args.batchsize,
        pin_memory=pin_memory, collate_fn=collate_fn)

    trainer = ignite.engine.create_supervised_trainer(
        model, optimizer, F.cross_entropy, device=device)

    # Plain (single-process) Ignite evaluator. It runs on this worker's shard
    # of the test set, and its metrics are not aggregated across workers.
    # NOTE(review): a multi-node evaluator wrapper was considered but is not
    # applied here — TODO confirm whether per-worker metrics are intended.
    evaluator = ignite.engine.create_supervised_evaluator(
        model,
        metrics={
            'loss': ignite.metrics.Loss(F.cross_entropy),
            'accuracy': ignite.metrics.Accuracy(),
        },
        device=device)

    @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
    def validation(engine):
        """Evaluate on the local test shard and report the metrics."""
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        reporter.report({
            'loss': metrics['loss'],
            'accuracy': metrics['accuracy'],
        })

    # Presumably the Chainer extension bridge looks the model up via
    # ``optimizer.target``, as with a Chainer optimizer — confirm in cpm.
    optimizer.target = model
    # Honor --out instead of always writing to a hard-coded 'result' dir.
    trainer.out = args.out

    snapshot = extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}')
    # Replica sets for the multi-node snapshot: rank 0 alone, and all others.
    replica_sets = [[0], range(1, comm.size)]
    snapshot = chainermn.extensions.multi_node_snapshot(
        comm, snapshot, replica_sets)
    cpm.ignite.add_trainer_extension(
        trainer, optimizer, snapshot, trigger=(1, 'epoch'))

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        cpm.ignite.add_trainer_extension(
            trainer, optimizer, extensions.LogReport())
        cpm.ignite.add_trainer_extension(
            trainer, optimizer, extensions.ProgressBar())

    if args.resume:
        cpm.ignite.load_chainer_snapshot(trainer, optimizer, args.resume)

    trainer.run(train_loader, max_epochs=args.epoch)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment