""" | |
Chocoball Detector | |
Training Code (SSD) | |
""" | |
import os | |
import numpy as np | |
import random | |
import copy | |
from PIL import Image | |
import chainer | |
from chainer.datasets import ConcatenatedDataset | |
from chainer.datasets import TransformDataset | |
from chainer.optimizer_hooks import WeightDecay | |
from chainer import training | |
from chainer.training import extensions | |
from chainer.training import triggers | |
from chainercv.links import SSD300 | |
from chainercv.links.model.ssd import multibox_loss | |
from chainercv.links.model.ssd import GradientScaling | |
from chainercv.links.model.ssd import multibox_loss | |
from chainercv import transforms | |
from chainercv.links.model.ssd import random_crop_with_bbox_constraints | |
from chainercv.links.model.ssd import random_distort | |
from chainercv.links.model.ssd import resize_with_random_interpolation | |
from chainercv.extensions import DetectionVOCEvaluator | |
from utils.callbacks import Statistics | |
from utils.tensorboard import Tensorboard | |
from dataset import load_dataset_from_api | |
from dataset import load_classes | |
from dataset import DetectionDatasetFromAPI | |

# ----- Global Vars
nb_iterations = int(os.environ.get('N_ITER', 3))  # 120000
TRAIN_RATE = 0.8
BATCHSIZE = int(os.environ.get('N_BATCH', 5))
USE_GPU = int(os.environ.get('USE_GPU', '-1'))
ABEJA_TRAINING_RESULT_DIR = os.environ.get('ABEJA_TRAINING_RESULT_DIR', '.')
log_path = os.path.join(ABEJA_TRAINING_RESULT_DIR, 'logs')
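
# NOTE: nb_iterations defaults to 3 so the script can be smoke-tested quickly.
# For a full training run, set N_ITER to roughly 120000 so the learning-rate
# drops scheduled below at 80000 and 100000 iterations actually fire.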

# ----- Classes
class MultiboxTrainChain(chainer.Chain):
    """MultiboxTrainChain

    Wraps the SSD model and computes the multibox loss:
    localization loss weighted by ``alpha`` plus confidence loss
    with hard negative mining ratio ``k``.
    https://github.com/chainer/chainercv/blob/master/examples/ssd/train.py
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, self.k)
        loss = loc_loss * self.alpha + conf_loss
        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)
        return loss


class Transform(object):
    """Data augmentation and preprocessing for SSD training."""

    def __init__(self, coder, size, mean):
        # copy the coder so it can be moved to the CPU
        # (the transform may run in a separate worker process)
        self.coder = copy.copy(coder)
        self.coder.to_cpu()
        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        # There are five data augmentation steps
        # 1. Color augmentation
        # 2. Random expansion
        # 3. Random cropping
        # 4. Resizing with random interpolation
        # 5. Random horizontal flipping
        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'], x_offset=param['x_offset'])

        # 3. Random cropping
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # 4. Resizing with random interpolation
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Random horizontal flipping
        img, params = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, (self.size, self.size), x_flip=params['x_flip'])

        # Preparation for SSD network
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)
        return img, mb_loc, mb_label


def handler(context):
    """Training entry point on ABEJA Platform.

    Args:
        context: runtime metadata, etc.
            context.datasets: dict whose keys are dataset aliases and
                whose values are dataset IDs
    Returns:
        None
    """
    dataset_alias = context.datasets
    dataset_id = dataset_alias['chocoballs']  # {dataset_name: dataset_id}
    print('ITER NUM : ', nb_iterations)
    print('BATCH SIZE : ', BATCHSIZE)
    main(dataset_id=dataset_id)


def main(dataset_id, organization_id=None, credential=None):
    # dataset from ABEJA Platform
    dataset = list(load_dataset_from_api(
        dataset_id, organization_id, credential))

    # Split dataset (TrainSet and TestSet)
    N = len(dataset)
    N_train = int(N * TRAIN_RATE)
    idxs = list(np.arange(N))
    random.shuffle(idxs)
    dataset_train = [dataset[idx] for idx in idxs[:N_train]]
    dataset_test = [dataset[idx] for idx in idxs[N_train:]]
    print('dataset length, total:{}, train:{}, test:{}'.format(
        N, len(dataset_train), len(dataset_test)))

    # set-up model
    classes = load_classes('classes.txt')
    model = SSD300(
        n_fg_class=len(classes),
        pretrained_model='ssd300_voc0712_2017_06_06_extractor.npz')
    print('Model : SSD300')
    model.use_preset('evaluate')

    # set chain
    train_chain = MultiboxTrainChain(model)
    if USE_GPU >= 0:
        chainer.cuda.get_device_from_id(USE_GPU).use()
        model.to_gpu()

    # data iterator
    trainval = DetectionDatasetFromAPI(dataset_train)
    testval = DetectionDatasetFromAPI(dataset_test)
    train = TransformDataset(trainval,
                             Transform(
                                 model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE)
    test_iter = chainer.iterators.SerialIterator(
        testval, BATCHSIZE, repeat=False, shuffle=False)

    # Optimizer
    # Following the ChainerCV SSD example: bias parameters get their
    # gradients scaled by 2 (and no weight decay), all other parameters
    # get weight decay of 0.0005.
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    # set-up training
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=USE_GPU)
    trainer = training.Trainer(
        updater, (nb_iterations, 'iteration'),
        out=ABEJA_TRAINING_RESULT_DIR)
    trainer.extend(
        extensions.ExponentialShift('lr', 0.1, init=1e-3),
        trigger=triggers.ManualScheduleTrigger([80000, 100000], 'iteration'))
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=classes),
        trigger=(10000, 'iteration'))

    log_interval = 1, 'iteration'
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)

    print_entries = ['iteration',
                     'main/loss', 'main/loss/loc', 'main/loss/conf',
                     'validation/main/map']
    report_entries = ['epoch', 'iteration', 'lr',
                      'main/loss', 'main/loss/loc', 'main/loss/conf',
                      'validation/main/map']
    trainer.extend(Statistics(report_entries, nb_iterations,
                              obs_key='iteration'), trigger=log_interval)
    trainer.extend(Tensorboard(report_entries, out_dir=log_path))
    trainer.extend(extensions.PrintReport(print_entries), trigger=log_interval)

    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(nb_iterations, 'iteration'))

    trainer.run()


if __name__ == '__main__':
    from dotenv import load_dotenv
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='1704553036773')
    args = parser.parse_args()

    from abejacli.config import (
        ABEJA_PLATFORM_USER_ID,
        ABEJA_PLATFORM_TOKEN
    )
    credential = {
        'user_id': ABEJA_PLATFORM_USER_ID,
        'personal_access_token': ABEJA_PLATFORM_TOKEN
    }

    dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
    load_dotenv(dotenv_path)
    ORG_ID = os.environ.get("ORGANIZATION_ID")

    dataset_id = args.dataset
    main(dataset_id=dataset_id, organization_id=ORG_ID, credential=credential)
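
# Example local invocation (assuming this file is saved as train.py and that
# abejacli and the .env file are configured with valid credentials):
#   USE_GPU=0 N_ITER=120000 N_BATCH=32 python train.py --dataset <dataset_id>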