Created
October 4, 2018 14:56
-
-
Save thorstenwagner/0c83fd5eab3d7de551b6d202b9de49fa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import json | |
import os | |
from keras.utils import multi_gpu_model | |
import numpy as np | |
import time | |
from . import imagereader | |
from .frontend import YOLO | |
from .preprocessing import parse_annotation2 | |
from . import config_tools | |
from .my_multi_gpu_model import my_multi_gpu_model | |
import psutil | |
# Environment configuration applied when this module is imported.
# Order CUDA devices by PCI bus ID (CUDA's default order is FASTEST_FIRST),
# so the numeric ids given via --gpu are stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Default to GPU 0 unless the caller already restricted the visible devices.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# Disable HDF5 file locking. NOTE(review): presumably so model/weight files on
# shared or network file systems can be opened — confirm.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
# Command-line interface of the training script.
argparser = argparse.ArgumentParser(
    description='Train crYOLO model on any dataset',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

argparser.add_argument(
    '-c',
    '--conf',
    required=True,
    help='path to configuration file')

argparser.add_argument(
    '-w',
    '--warmup',
    type=int,
    default=None,
    help='Number of warmup epochs. If none, it tries to read it from the config file. If it is not defined in the '
         'config file, it is set to 0.')

argparser.add_argument(
    '-p',
    '--patch',
    default=None,
    type=int,
    help='Number of patches. If none, it tries to read it from the config file. If it is not defined in the '
         'config file, it is set to 1.')

argparser.add_argument(
    '-e',
    '--early',
    default=5,
    type=int,
    help='Early stop patience. If the validation loss did not improve longer than the early stop patience, '
         'the training is stopped.')

# nargs="+" always yields a list when the option is given, so the default is a
# list too; args.gpu therefore has a single type in every case.
argparser.add_argument(
    '-g',
    '--gpu',
    default=[0],
    type=int,
    nargs="+",
    help="Specify which gpu(s) should be used. Multiple GPUs are separated by a whitespace")

argparser.add_argument(
    '--warm_restarts',
    action="store_true",
    help="Use warm restarts and cosine annealing during training")
def _resolve_gpus(gpu_arg):
    """Normalise the ``--gpu`` argument.

    Returns a ``(gpu_ids, num_gpus)`` pair where ``gpu_ids`` is a list of GPU
    ids as strings, ready to be comma-joined into ``CUDA_VISIBLE_DEVICES``.
    A scalar value is wrapped in a one-element list so a multi-digit id such
    as 12 is not split into "1,2" by ``','.join``.
    """
    if isinstance(gpu_arg, list):
        return [str(entry) for entry in gpu_arg], len(gpu_arg)
    return [str(gpu_arg)], 1


def _resolve_warmup(warmup_arg, config):
    """Return the number of warmup epochs.

    Priority: the command-line value, then ``config['train']['warmup_epochs']``,
    then 0. A missing config key is treated like an explicit ``null`` instead
    of raising KeyError.
    """
    if warmup_arg is not None:
        print("Read warmup by argument")
        return int(warmup_arg)
    config_warmup = config['train'].get('warmup_epochs')
    if config_warmup is not None:
        print("Read warmup by config")
        return config_warmup
    print("Set warmup to zero")
    return 0


def _write_runfile(valid_imgs):
    """Record the validation image and box-file paths in runfiles/run_<timestamp>.json.

    The ``runfiles/`` directory is created relative to the current working
    directory if it does not exist yet.
    """
    runjson = {
        "run": {
            "valid_images": [item['filename'] for item in valid_imgs],
            "valid_annot": [item['boxpath'] for item in valid_imgs],
        }
    }
    if not os.path.exists("runfiles/"):
        os.mkdir("runfiles/")
    timestr = time.strftime("%Y%m%d-%H%M%S")
    with open('runfiles/run_' + timestr + '.json', 'w') as outfile:
        json.dump(runjson, outfile)


def _load_or_split_validation(config, train_imgs, grid_dims):
    """Return ``(train_imgs, valid_imgs)``.

    If ``config['valid']['valid_annot_folder']`` exists, the validation set is
    parsed from it and ``train_imgs`` is returned unchanged. Otherwise
    ``train_imgs`` is shuffled in place with a fixed seed (10) and split 80/20
    into training and validation sets. The chosen validation files are then
    recorded via ``_write_runfile``.
    """
    valid_annot_folder = config['valid']['valid_annot_folder']
    if os.path.exists(valid_annot_folder):
        # NOTE(review): unlike the training set, no anchor_size is passed here;
        # confirm parse_annotation2's default is intended.
        valid_imgs, valid_labels = parse_annotation2(valid_annot_folder,
                                                     config['valid']['valid_image_folder'],
                                                     grid_dims)
        if len(valid_imgs) == 0:
            print("No validation images were found. Invalid validation configuration. Check your config file.")
        if len(valid_labels) == 0:
            print("No validation labels were found. Invalid validation configuration. Check your config file.")
        # NOTE(review): training still proceeds when the validation set is
        # empty; confirm this is intended.
        return train_imgs, valid_imgs

    np.random.seed(10)  # fixed seed -> reproducible train/validation split
    split_index = int(0.8 * len(train_imgs))
    np.random.shuffle(train_imgs)
    valid_imgs = train_imgs[split_index:]
    train_imgs = train_imgs[:split_index]
    print("Validation set:")
    print([item['filename'] for item in valid_imgs])
    _write_runfile(valid_imgs)
    return train_imgs, valid_imgs


def _is_greyscale(img):
    """Return True if ``img`` carries no colour information.

    Treated as greyscale: a 2-D array, a 3-D array with a single channel,
    or a 3-D array whose first three channels are identical. The
    single-channel case is checked explicitly because indexing channel 1 of
    such an array would raise IndexError.
    """
    if len(img.shape) == 2 or img.shape[2] == 1:
        return True
    return bool(np.all(img[:, :, 0] == img[:, :, 1]) and
                np.all(img[:, :, 0] == img[:, :, 2]))


def _compute_anchors(config, img_first, grid_w, grid_h, num_patches):
    """Return the anchor sizes expressed in grid-cell units.

    With more than one anchor pair in ``config['model']['anchors']`` (general
    model), the configured sizes are scaled against a fixed 4096-wide image.
    Otherwise (specific model), only one box size is expected, so a single
    anchor is derived automatically from the configured box size and the
    dimensions of the first training image.
    """
    anchor_config = config['model']['anchors']
    if len(anchor_config) > 2:
        # NOTE(review): 4096 appears to be the reference image edge of the
        # general model; confirm.
        cell_w = (4096.0 / num_patches) / grid_w
        cell_h = (4096.0 / num_patches) / grid_h
        anchors = np.array(anchor_config, dtype=float)
        anchors[::2] = anchors[::2] / cell_w    # even entries scaled by cell width
        anchors[1::2] = anchors[1::2] / cell_h  # odd entries scaled by cell height
        return anchors
    cell_w = (float(img_first.shape[1]) / num_patches) / grid_w
    cell_h = (float(img_first.shape[0]) / num_patches) / grid_h
    box_width, box_height = config_tools.get_box_size(config)
    return [1.0 * box_width / cell_w, 1.0 * box_height / cell_h]


def _resolve_log_path(config):
    """Return the log directory, creating it (including parents) if needed.

    Falls back to ``~/logs/`` when ``config['train']['log_path']`` is missing or
    empty. ``~`` is expanded, so the default lands in the user's home directory
    rather than a literal ``./~`` folder (which ``os.mkdir`` could not create).
    """
    log_path = os.path.expanduser(config['train'].get('log_path') or '~/logs/')
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    return log_path


def _resolve_overlap_patches(config):
    """Return the patch overlap passed to training.

    Uses ``config['model']['overlap_patches']`` if present. Otherwise, for a
    specific model (at most one anchor pair), it uses the first anchor value;
    else 0. Both config-derived values are cast to int.
    """
    model_config = config['model']
    if 'overlap_patches' in model_config:
        return int(model_config['overlap_patches'])
    if len(model_config['anchors']) <= 2:
        return int(model_config['anchors'][0])
    return 0


def _build_yolo(config, depth, anchors, num_gpus):
    """Construct the YOLO model described by ``config['model']``.

    For multi-GPU training the template model is created under a CPU device
    scope, as the Keras ``multi_gpu_model`` documentation recommends.
    """
    yolo_kwargs = dict(architecture=config['model']['architecture'],
                       input_size=config['model']['input_size'],
                       input_depth=depth,
                       labels=config['model']['labels'],
                       max_box_per_image=config['model']['max_box_per_image'],
                       anchors=anchors,
                       backend_weights=config['model'].get('backend_weights'))
    if num_gpus > 1:
        import tensorflow as tf
        with tf.device('/cpu:0'):
            return YOLO(**yolo_kwargs)
    return YOLO(**yolo_kwargs)


def _main_():
    """Train a crYOLO model as configured on the command line.

    Steps:
      1. Read the JSON config given by ``--conf``.
      2. Restrict ``CUDA_VISIBLE_DEVICES`` to the GPUs given by ``--gpu``.
      3. Parse the training annotations, and either load a validation set or
         split one off the training data.
      4. Derive the input depth and anchors from the first training image.
      5. Build the model and optionally load pretrained weights.
      6. Train, and print the elapsed wall-clock time.

    Raises:
        ValueError: if no training images remain after parsing/splitting.
    """
    args = argparser.parse_args()
    with open(args.conf) as config_buffer:
        config = json.loads(config_buffer.read())

    if args.patch is not None:
        num_patches = int(args.patch)
    else:
        num_patches = config_tools.get_number_patches(config)

    gpu_ids, num_gpus = _resolve_gpus(args.gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)

    warmup_epochs = _resolve_warmup(args.warmup, config)
    # The early-stop patience is never shorter than the warmup period.
    early_stop = max(int(args.early), warmup_epochs)

    grid_w, grid_h = config_tools.get_gridcell_dimensions(config)
    grid_dims = (grid_w, grid_h, num_patches)

    # Parse annotations of the training set.
    train_imgs, _ = parse_annotation2(ann_dir=config['train']['train_annot_folder'],
                                      img_dir=config['train']['train_image_folder'],
                                      grid_dims=grid_dims,
                                      anchor_size=int(config['model']['anchors'][0]))
    config['model']['labels'] = ['particle']

    train_imgs, valid_imgs = _load_or_split_validation(config, train_imgs, grid_dims)
    if not train_imgs:
        raise ValueError("No training images were found. Check your config file.")

    # The first training image decides the input depth (1 = grey, 3 = colour).
    img_first = imagereader.image_read(train_imgs[0]['filename'])
    depth = 1 if _is_greyscale(img_first) else 3
    anchors = _compute_anchors(config, img_first, grid_w, grid_h, num_patches)

    log_path = _resolve_log_path(config)
    overlap_patches = _resolve_overlap_patches(config)

    yolo = _build_yolo(config, depth, anchors, num_gpus)

    # Load the pretrained weights, if configured and present on disk.
    pretrained_weights = config['train'].get('pretrained_weights')
    if pretrained_weights and os.path.exists(pretrained_weights):
        print("Loading pre-trained weights in", pretrained_weights)
        yolo.load_weights(pretrained_weights)

    parallel_model = None
    if num_gpus > 1:
        parallel_model = multi_gpu_model(yolo.model, gpus=num_gpus)

    start = time.time()
    yolo.train(train_imgs=train_imgs,
               valid_imgs=valid_imgs,
               train_times=config['train']['train_times'],
               valid_times=config['valid']['valid_times'],
               nb_epoch=config['train']['nb_epoch'],
               learning_rate=config['train']['learning_rate'],
               batch_size=config['train']['batch_size'],
               warmup_epochs=warmup_epochs,
               object_scale=config['train']['object_scale'],
               no_object_scale=config['train']['no_object_scale'],
               coord_scale=config['train']['coord_scale'],
               class_scale=config['train']['class_scale'],
               saved_weights_name=config['train']['saved_weights_name'],
               debug=config['train']['debug'],
               log_path=log_path,
               early_stop_thresh=early_stop,
               num_patches=num_patches,
               warm_restarts=args.warm_restarts,
               overlap_patches=overlap_patches,
               parallel_model=parallel_model)
    end = time.time()
    print("Time elapsed for training:", (end - start))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment