"""Chainer multi-node training on Azure BatchAI."""
import argparse
import logging
import os
from os import path
import numpy as np
import pandas as pd
import multiprocessing
from PIL import Image
from chainercv import transforms
import chainer
import chainer.cuda
from chainer import training
from chainer.training import extensions
import resnet50
from mpi4py import MPI
import chainermn
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Distributed training settings
parser = argparse.ArgumentParser(
    description='Chainer ResNet Example')
parser.add_argument('--communicator', default='hierarchical')
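# The default 'hierarchical' communicator uses NCCL within each node and MPI
# across nodes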
_WIDTH = 224
_HEIGHT = 224
_LR = 0.001
_EPOCHS = 1
_BATCHSIZE = 64
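# Caffe-style ImageNet preprocessing constants: subtract the per-channel RGB
# mean, then scale by 0.017 (roughly 1/58.8, approximating division by the
# ImageNet pixel standard deviation)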
_IMAGENET_RGB_MEAN_CAFFE = np.array([123.68, 116.78, 103.94], dtype=np.float32)
_IMAGENET_SCALE_FACTOR_CAFFE = 0.017
args = parser.parse_args()


def _append_path_to(data_path, data_series):
    return data_series.apply(lambda x: path.join(data_path, x))


def _load_training(data_dir):
    train_df = pd.read_csv(path.join(data_dir, 'train.csv'))
    return train_df.assign(filenames=_append_path_to(path.join(data_dir, 'train'),
                                                     train_df.filenames))


def _load_validation(data_dir):
    valid_df = pd.read_csv(path.join(data_dir, 'validation.csv'))
    return valid_df.assign(filenames=_append_path_to(path.join(data_dir, 'validation'),
                                                     valid_df.filenames))


def _create_data_fn(train_path, test_path):
    logger.info('Reading training data info')
    train_df = _load_training(train_path)
    logger.info('Reading validation data info')
    validation_df = _load_validation(test_path)
    # File paths
    train_X = train_df['filenames'].values
    validation_X = validation_df['filenames'].values
    # Integer class labels (not one-hot; Chainer expects int class ids)
    train_labels = train_df[['num_id']].values.ravel()
    validation_labels = validation_df[['num_id']].values.ravel()
    # Labels in the CSVs are 1-based; shift so the index starts from 0
    train_labels -= 1
    validation_labels -= 1
    return train_X, train_labels, validation_X, validation_labels


class ImageNet(chainer.dataset.DatasetMixin):
    def __init__(self, img_locs, labels, augmentation=None):
        self.img_locs, self.labels = img_locs, labels
        self.augmentation = augmentation
        self.imagenet_mean = _IMAGENET_RGB_MEAN_CAFFE
        self.imagenet_scaling = _IMAGENET_SCALE_FACTOR_CAFFE
        logger.info("Loaded {} labels and {} images".format(len(self.labels), len(self.img_locs)))

    def __len__(self):
        return len(self.img_locs)

    def get_example(self, idx):
        im_file = self.img_locs[idx]
        # Load as RGB (some files may be greyscale or CMYK)
        im_rgb = Image.open(im_file)
        im_rgb = im_rgb.convert('RGB')
        im_rgb = self._apply_data_preprocessing(im_rgb)
        label = self.labels[idx]
        if self.augmentation is not None:
            im_rgb = self._apply_data_augmentation(im_rgb)
        else:
            im_rgb = transforms.resize(im_rgb, size=(_HEIGHT, _WIDTH))
        return np.array(im_rgb, dtype=np.float32), \
            np.array(label, dtype=np.int32)

    def _apply_data_preprocessing(self, rgb_im):
        # To a float32 array
        im = np.asarray(rgb_im, dtype=np.float32)
        # (h, w, c) to (c, h, w)
        im = im.transpose(2, 0, 1)
        # Caffe normalisation
        im -= self.imagenet_mean[:, None, None]
        im *= self.imagenet_scaling
        return im

    def _apply_data_augmentation(self, im):
        im = transforms.random_crop(im, size=(_HEIGHT, _WIDTH))
        im = transforms.random_flip(im)
        return im


class TestModeEvaluator(extensions.Evaluator):
    def evaluate(self):
        model = self.get_target('main')
        model.train = False
        ret = super(TestModeEvaluator, self).evaluate()
        model.train = True
        return ret
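# Note: toggling a `train` attribute like this assumes the custom resnet50
# module checks it (Chainer v1 style); from Chainer v2 onwards the usual
# mechanism is chainer.using_config('train', False).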


def main():
    # Prepare the ChainerMN communicator
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
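    # With the 'hierarchical' communicator, intra_rank is this process's rank
    # within its node, so (assuming one MPI process per GPU) it doubles as the
    # local GPU id used below.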
    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num processes (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        print('Using {} communicator'.format(args.communicator))
        print('Minibatch size: {}'.format(_BATCHSIZE))
        print('Num epochs: {}'.format(_EPOCHS))
        print('==========================================')
    model = resnet50.ResNet50()
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()
    # Create a multi-node optimizer from a standard Chainer optimizer
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=_LR, momentum=0.9), comm)
    optimizer.setup(model)
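    # The multi-node optimizer all-reduces gradients across all workers before
    # each update, so every replica applies the same averaged gradient.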
    # Split and distribute the dataset. Only worker 0 reads the file lists;
    # its datasets are then evenly split and scattered to all workers.
    if comm.rank == 0:
        train_X, train_y, valid_X, valid_y = _create_data_fn(os.getenv('AZ_BATCHAI_INPUT_TRAIN'),
                                                             os.getenv('AZ_BATCHAI_INPUT_TEST'))
        # Augmentation is disabled for now (a size issue with random_crop), so
        # images are simply resized to 224x224
        train = ImageNet(train_X, train_y)
        val = ImageNet(valid_X, valid_y)
    else:
        train = None
        val = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    val = chainermn.scatter_dataset(val, comm)
    # MultiprocessIterator pre-fetches batches in worker processes; fall back
    # to SerialIterator if multiprocessing causes problems:
    # train_iter = chainer.iterators.SerialIterator(train, _BATCHSIZE)
    # val_iter = chainer.iterators.SerialIterator(val, _BATCHSIZE, repeat=False)
    # multiprocessing.set_start_method('forkserver')
    train_iter = chainer.iterators.MultiprocessIterator(train, _BATCHSIZE, n_processes=24)
    val_iter = chainer.iterators.MultiprocessIterator(val, _BATCHSIZE, repeat=False, n_processes=24)
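    # n_processes=24 assumes a 24-core VM; adjust to the number of CPU cores
    # available per node.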
    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (_EPOCHS, 'epoch'))

    # No checkpointing for now
    val_interval = (1, 'epoch')
    log_interval = (1, 'epoch')
    # Create a multi-node evaluator from a standard evaluator
    evaluator = TestModeEvaluator(val_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=val_interval)
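    # Each worker evaluates only its shard of the scattered validation set; the
    # multi-node evaluator then averages the reported results across workers.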
    # Display and output extensions are needed only on one worker
    # (otherwise the output would simply be repeated)
    if comm.rank == 0:
        trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))
    # All workers must enter the training loop
    trainer.run()


if __name__ == '__main__':
    main()
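
# Example launch (a sketch; the script name and process counts are assumptions,
# and the exact Batch AI job configuration may differ). With one MPI process
# per GPU, e.g. 2 nodes x 4 GPUs:
#   mpirun -np 8 python train_chainer.py --communicator hierarchical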