"""
Chocoball Detector
Training Code (SSD)
"""
import os
import numpy as np
import random
import copy
from PIL import Image
import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer.optimizer_hooks import WeightDecay
from chainer import training
from chainer.training import extensions
from chainer.training import triggers
from chainercv.links import SSD300
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import GradientScaling
from chainercv import transforms
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation
from chainercv.extensions import DetectionVOCEvaluator
from utils.callbacks import Statistics
from utils.tensorboard import Tensorboard
from dataset import load_dataset_from_api
from dataset import load_classes
from dataset import DetectionDatasetFromAPI
# ----- Global Vars
nb_iterations = int(os.environ.get('N_ITER', 3)) # 120000
TRAIN_RATE = 0.8
BATCHSIZE = int(os.environ.get('N_BATCH', 5))
USE_GPU = int(os.environ.get('USE_GPU', '-1'))
ABEJA_TRAINING_RESULT_DIR = os.environ.get('ABEJA_TRAINING_RESULT_DIR', '.')
log_path = os.path.join(ABEJA_TRAINING_RESULT_DIR, 'logs')
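# Note: N_ITER, N_BATCH and USE_GPU are environment-variable knobs (USE_GPU=-1
# keeps training on the CPU); logs and snapshots are written under
# ABEJA_TRAINING_RESULT_DIR (default: the current directory).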
# ----- Classes
class MultiboxTrainChain(chainer.Chain):
    """MultiboxTrainChain
    https://github.com/chainer/chainercv/blob/master/examples/ssd/train.py
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, self.k)
        loss = loc_loss * self.alpha + conf_loss

        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)

        return loss
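# Illustrative use (a sketch, not invoked by this script): the chain simply wraps
# an SSD model and reports the weighted multibox loss, e.g.
#   chain = MultiboxTrainChain(SSD300(n_fg_class=2, pretrained_model='imagenet'))
#   loss = chain(imgs, gt_mb_locs, gt_mb_labels)  # imgs: (B, 3, 300, 300) float32
# ``alpha`` weights the localization term and ``k`` is the hard-negative mining
# ratio passed to multibox_loss.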
class Transform(object):

    def __init__(self, coder, size, mean):
        # to send to the CPU, make a copy
        self.coder = copy.copy(coder)
        self.coder.to_cpu()

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        # There are five data augmentation steps
        # 1. Color augmentation
        # 2. Random expansion
        # 3. Random cropping
        # 4. Resizing with random interpolation
        # 5. Random horizontal flipping

        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'], x_offset=param['x_offset'])

        # 3. Random cropping
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # 4. Resizing with random interpolation
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Random horizontal flipping
        img, params = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, (self.size, self.size), x_flip=params['x_flip'])

        # Preparation for SSD network
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label
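# Note: Transform is applied lazily, per sample, through TransformDataset in
# main(); each (img, bbox, label) tuple becomes (img, mb_loc, mb_label) with the
# channel mean subtracted and the boxes encoded against the SSD default boxes.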
def handler(context):
    """Training entry point called by the ABEJA Platform.

    Args:
        context: run-time metadata. ``context.datasets`` is a dict whose keys
            are dataset aliases and whose values are dataset IDs.
    Returns:
        None
    """
    dataset_alias = context.datasets
    dataset_id = dataset_alias['chocoballs']  # {dataset_name: dataset_id}
    print('ITER NUM : ', nb_iterations)
    print('BATCH SIZE : ', BATCHSIZE)
    main(dataset_id=dataset_id)
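# e.g. context.datasets == {'chocoballs': '<dataset_id>'} when the training job
# is created with the dataset alias 'chocoballs'.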
def main(dataset_id, organization_id=None, credential=None):
    # dataset from ABEJA Platform
    dataset = list(load_dataset_from_api(
        dataset_id, organization_id, credential))

    # Split dataset (TrainSet and TestSet)
    N = len(dataset)
    N_train = int(N * TRAIN_RATE)
    idxs = list(np.arange(N))
    random.shuffle(idxs)
    dataset_train = [dataset[idx] for idx in idxs[:N_train]]
    dataset_test = [dataset[idx] for idx in idxs[N_train:]]
    print('dataset length, total:{}, train:{}, test:{}'.format(
        N, len(dataset_train), len(dataset_test)))

    # set-up model
    classes = load_classes('classes.txt')
    model = SSD300(
        n_fg_class=len(classes),
        pretrained_model='ssd300_voc0712_2017_06_06_extractor.npz')
    print('Model : SSD300')
    model.use_preset('evaluate')

    # set chain
    train_chain = MultiboxTrainChain(model)
    if USE_GPU >= 0:
        chainer.cuda.get_device_from_id(USE_GPU).use()
        model.to_gpu()

    # data iterator
    trainval = DetectionDatasetFromAPI(dataset_train)
    testval = DetectionDatasetFromAPI(dataset_test)
    train = TransformDataset(trainval,
                             Transform(
                                 model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE)
    test_iter = chainer.iterators.SerialIterator(
        testval, BATCHSIZE, repeat=False, shuffle=False)

    # Optimizer
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))
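    # Bias terms ('b') get their gradient doubled via GradientScaling(2) while
    # every other parameter is regularized with WeightDecay(0.0005), following
    # the ChainerCV SSD training example.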
    # set-up training
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=USE_GPU)
    trainer = training.Trainer(
        updater, (nb_iterations, 'iteration'),
        out=ABEJA_TRAINING_RESULT_DIR)
    trainer.extend(
        extensions.ExponentialShift('lr', 0.1, init=1e-3),
        trigger=triggers.ManualScheduleTrigger([80000, 100000], 'iteration'))
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=classes),
        trigger=(10000, 'iteration'))

    log_interval = 1, 'iteration'
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)

    print_entries = ['iteration',
                     'main/loss', 'main/loss/loc', 'main/loss/conf',
                     'validation/main/map']
    report_entries = ['epoch', 'iteration', 'lr',
                      'main/loss', 'main/loss/loc', 'main/loss/conf',
                      'validation/main/map']

    trainer.extend(Statistics(report_entries, nb_iterations,
                              obs_key='iteration'), trigger=log_interval)
    trainer.extend(Tensorboard(report_entries, out_dir=log_path))
    trainer.extend(extensions.PrintReport(print_entries), trigger=log_interval)

    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(nb_iterations, 'iteration'))

    trainer.run()
if __name__ == '__main__':
    from dotenv import load_dotenv
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='1704553036773')
    args = parser.parse_args()

    from abejacli.config import (
        ABEJA_PLATFORM_USER_ID,
        ABEJA_PLATFORM_TOKEN
    )
    credential = {
        'user_id': ABEJA_PLATFORM_USER_ID,
        'personal_access_token': ABEJA_PLATFORM_TOKEN
    }

    dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
    load_dotenv(dotenv_path)
    ORG_ID = os.environ.get("ORGANIZATION_ID")

    dataset_id = args.dataset
    main(dataset_id=dataset_id, organization_id=ORG_ID, credential=credential)
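# Local run example (a sketch; assumes abejacli credentials are configured and a
# .env file next to this script provides ORGANIZATION_ID):
#   USE_GPU=0 N_ITER=120000 N_BATCH=32 python <this_script>.py --dataset <dataset_id>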