Last active
February 7, 2020 06:22
-
-
Save msakai/06cf2f4ad937ba3215daee227695499e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import ignite | |
import chainer | |
from chainer import reporter | |
from chainer.training import extensions | |
import chainer_pytorch_migration as cpm | |
from chainer_pytorch_migration import chainermn | |
import chainer_pytorch_migration.ignite | |
class MLP(nn.Module):
    """Three-layer fully connected network: n_in -> n_units -> n_units -> n_out.

    Hidden layers use ReLU activations; the final layer returns raw logits
    (no softmax), suitable for use with ``F.cross_entropy``.

    Args:
        n_in: Size of each input sample (784 for flattened MNIST).
        n_units: Width of the two hidden layers.
        n_out: Number of output classes.
    """

    def __init__(self, n_in, n_units, n_out):
        super().__init__()
        self.l1 = nn.Linear(n_in, n_units)      # n_in -> n_units
        self.l2 = nn.Linear(n_units, n_units)   # n_units -> n_units
        self.l3 = nn.Linear(n_units, n_out)     # n_units -> n_out

    def forward(self, x):
        # Define forward() instead of overriding __call__ directly: this lets
        # nn.Module.__call__ run its hook machinery, while model(x) still
        # dispatches here, so existing callers are unaffected.
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
def main():
    """Run distributed MNIST training: PyTorch model, Ignite loop, ChainerMN comm.

    Parses command-line options, creates a ChainerMN communicator (the
    requested one on GPU, 'naive' on CPU), scatters the MNIST dataset from
    rank 0 to all workers, and drives a supervised Ignite trainer with
    Chainer-style extensions (snapshot, LogReport, ProgressBar) bridged via
    chainer_pytorch_migration.
    """
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='pure_nccl', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    parser.add_argument('--chainerx', '-x', action='store_true',
                        default=False, help='Use ChainerX')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.  On GPU, each process is pinned to its
    # intra-node rank's device; on CPU only the 'naive' communicator works.
    if args.gpu:
        if args.communicator == 'naive':
            print('Error: \'naive\' communicator does not support GPU.\n')
            exit(-1)
        comm = chainermn.create_communicator(args.communicator)
        device = torch.device('cuda:{}'.format(comm.intra_rank))
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        comm = chainermn.create_communicator('naive')
        device = torch.device('cpu')

    # Only rank 0 prints the configuration, to avoid repeated output.
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = MLP(784, args.unit, 10)
    model.to(device)

    # Wrap a standard PyTorch optimizer into a multi-node optimizer that
    # all-reduces gradients across the communicator.
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    def collate_fn(minibatch):
        # Convert a list of (image, label) pairs from the Chainer dataset
        # into a pair of batched torch tensors.
        xs = []
        ys = []
        for x, y in minibatch:
            xs.append(x)
            ys.append(y)
        return torch.FloatTensor(xs), torch.LongTensor(ys)

    train_loader = torch.utils.data.DataLoader(
        train, shuffle=True, batch_size=args.batchsize, pin_memory=True,
        collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(
        test, shuffle=False,
        batch_size=args.batchsize, pin_memory=True, collate_fn=collate_fn)

    trainer = ignite.engine.create_supervised_trainer(
        model, optimizer, F.cross_entropy, device=device)

    evaluator = ignite.engine.create_supervised_evaluator(
        model,
        metrics={
            'loss': ignite.metrics.Loss(F.cross_entropy),
            'accuracy': ignite.metrics.Accuracy(),
        },
        device=device)
    # evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)

    @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
    def validation(engine):
        # Run evaluation at the end of every epoch and report metrics through
        # the Chainer reporter so LogReport can pick them up.
        evaluator.run(test_loader)
        reporter.report({
            'loss': evaluator.state.metrics['loss'],
            'accuracy': evaluator.state.metrics['accuracy'],
        })

    optimizer.target = model
    # Honor the --out option (was hard-coded to "result", silently ignoring
    # the parsed argument).
    trainer.out = args.out

    # Snapshot replicated on rank 0 and on one of the remaining ranks.
    snapshot = extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}')
    replica_sets = [[0], range(1, comm.size)]
    snapshot = chainermn.extensions.multi_node_snapshot(comm, snapshot, replica_sets)
    cpm.ignite.add_trainer_extension(trainer, optimizer, snapshot, trigger=(1, 'epoch'))

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.LogReport())
        cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ProgressBar())

    if args.resume:
        cpm.ignite.load_chainer_snapshot(trainer, optimizer, args.resume)

    trainer.run(train_loader, max_epochs=args.epoch)
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment