Skip to content

Instantly share code, notes, and snippets.

@thorstenwagner
Created October 4, 2018 14:56
Show Gist options
  • Save thorstenwagner/0c83fd5eab3d7de551b6d202b9de49fa to your computer and use it in GitHub Desktop.
Save thorstenwagner/0c83fd5eab3d7de551b6d202b9de49fa to your computer and use it in GitHub Desktop.
from __future__ import print_function
import argparse
import json
import os
from keras.utils import multi_gpu_model
import numpy as np
import time
from . import imagereader
from .frontend import YOLO
from .preprocessing import parse_annotation2
from . import config_tools
from .my_multi_gpu_model import my_multi_gpu_model
import psutil
# Process-wide environment defaults. They are set at import time, before
# _main_ runs and before TensorFlow is imported further down in this file.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Respect a CUDA_VISIBLE_DEVICES value exported by the caller; fall back to
# the first GPU only when the variable is not set at all.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
# Command line interface of the training script. The parser lives at module
# level; _main_ calls argparser.parse_args() on it.
argparser = argparse.ArgumentParser(
    description='Train crYOLO model on any dataset',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

argparser.add_argument(
    '-c',
    '--conf',
    required=True,
    help='path to configuration file')

argparser.add_argument(
    '-w',
    '--warmup',
    type=int,
    default=None,
    help='Number of warmup epochs. If none, it tries to read it from the config file. If it is not defined in the '
         'config file, it is set to 0.')

argparser.add_argument(
    '-p',
    '--patch',
    default=None,
    type=int,
    help='Number of patches. If none, it tries to read it from the config file. If it is not defined in the '
         'config file, it is set to 1.')

argparser.add_argument(
    '-e',
    '--early',
    default=5,
    type=int,
    help='Early stop patience. If the validation loss did not improve longer than the early stop patience, '
         'the training is stopped.')

# nargs="+" always yields a list when the option is given, so the default is
# a list as well; args.gpu then has a single type for every invocation.
argparser.add_argument(
    '-g',
    '--gpu',
    default=[0],
    type=int,
    nargs="+",
    help="Specify which gpu(s) should be used. Multiple GPUs are separated by a whitespace")

argparser.add_argument(
    '--warm_restarts',
    action="store_true",
    help="Use warm restarts and cosine annealing during training")
def _resolve_warmup_epochs(cli_warmup, config):
    """Return the warmup epochs: command line value, else config value, else 0."""
    if cli_warmup is not None:
        print("Read warmup by argument")
        return int(cli_warmup)

    # .get() so that a config without 'warmup_epochs' falls through to 0
    # (as the --warmup help text promises) instead of raising KeyError.
    config_warmup = config['train'].get('warmup_epochs')
    if config_warmup is not None:
        print("Read warmup by config")
        return config_warmup

    print("Set warmup to zero")
    return 0


def _is_greyscale(img):
    """Return True if img is two-dimensional or its first three channels are equal."""
    if len(img.shape) == 2:
        return True
    first_channel = img[:, :, 0]
    return bool(np.all(first_channel == img[:, :, 1]) and
                np.all(first_channel == img[:, :, 2]))


def _compute_anchors(config, img_first, num_patches, grid_w, grid_h):
    """Convert the configured anchors from pixels into grid-cell units.

    More than two anchor values select the "general model" branch, which uses
    a fixed 4096 pixel edge length instead of the size of the first image.
    Otherwise the box size from the config is scaled by the cell size derived
    from img_first.
    """
    configured = config['model']['anchors']
    if len(configured) > 2:
        # general model is used
        cell_w = (4096.0 / num_patches) / grid_w
        cell_h = (4096.0 / num_patches) / grid_h
        anchors = np.array(configured, dtype=float)
        # Even positions are divided by the cell width, odd ones by the height.
        anchors[::2] = anchors[::2] / cell_w
        anchors[1::2] = anchors[1::2] / cell_h
        return anchors

    # specific model is used
    cell_w = (float(img_first.shape[1]) / num_patches) / grid_w
    cell_h = (float(img_first.shape[0]) / num_patches) / grid_h
    box_width, box_height = config_tools.get_box_size(config)
    return [1.0 * box_width / cell_w, 1.0 * box_height / cell_h]


def _write_run_file(valid_imgs):
    """Store image and annotation paths of the validation set in runfiles/run_<time>.json."""
    runjson = {
        "run": {
            "valid_images": [item['filename'] for item in valid_imgs],
            "valid_annot": [item['boxpath'] for item in valid_imgs],
        }
    }
    if not os.path.exists("runfiles/"):
        os.mkdir("runfiles/")
    timestr = time.strftime("%Y%m%d-%H%M%S")
    with open('runfiles/run_' + timestr + '.json', 'w') as outfile:
        json.dump(runjson, outfile)


def _main_():
    """Train a crYOLO model as described by the command line and the config file.

    Steps: parse the arguments, read the JSON config, select the GPUs, parse
    training and validation annotations (or split off 20% of the training
    images when the validation folder does not exist), write a run file that
    lists the validation set, build the YOLO model, optionally load
    pretrained weights and finally call ``yolo.train``.

    Raises:
        ValueError: if no training images are left after parsing/splitting.
    """
    args = argparser.parse_args()

    with open(args.conf) as config_buffer:
        config = json.load(config_buffer)

    if args.patch is not None:
        num_patches = int(args.patch)
    else:
        num_patches = config_tools.get_number_patches(config)

    early_stop = int(args.early)
    warm_restarts = args.warm_restarts

    # Normalise to a list first: joining a plain string would split a
    # multi-digit GPU id into its single characters.
    gpu_ids = args.gpu if isinstance(args.gpu, list) else [args.gpu]
    str_gpus = [str(entry) for entry in gpu_ids]
    num_gpus = len(str_gpus)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str_gpus)

    warmup_epochs = _resolve_warmup_epochs(args.warmup, config)
    # The early stop patience is never allowed to be below the warmup length.
    early_stop = max(early_stop, warmup_epochs)

    grid_w, grid_h = config_tools.get_gridcell_dimensions(config)

    ###############################
    #   Parse the annotations
    ###############################

    # parse annotations of the training set
    train_imgs, train_labels = parse_annotation2(ann_dir=config['train']['train_annot_folder'],
                                                 img_dir=config['train']['train_image_folder'],
                                                 grid_dims=(grid_w, grid_h, num_patches),
                                                 anchor_size=int(config['model']['anchors'][0]))
    config['model']['labels'] = ['particle']

    # parse annotations of the validation set, if any, otherwise split the training set
    if os.path.exists(config['valid']['valid_annot_folder']):
        valid_imgs, valid_labels = parse_annotation2(config['valid']['valid_annot_folder'],
                                                     config['valid']['valid_image_folder'],
                                                     (grid_w, grid_h, num_patches))
        # Only a warning: training deliberately continues in these cases.
        if len(valid_imgs) == 0:
            print("No validation images were found. Invalid validation configuration. Check your config file.")
        if len(valid_labels) == 0:
            print("No validation labels were found. Invalid validation configuration. Check your config file.")
    else:
        # Fixed seed, so the same 80/20 split is drawn for the same input order.
        np.random.seed(10)
        train_valid_split = int(0.8 * len(train_imgs))
        np.random.shuffle(train_imgs)
        valid_imgs = train_imgs[train_valid_split:]
        train_imgs = train_imgs[:train_valid_split]

    print("Validation set:")
    print([item['filename'] for item in valid_imgs])
    _write_run_file(valid_imgs)

    if len(train_imgs) == 0:
        # Without this check the first-image lookup below fails with a bare IndexError.
        raise ValueError("No training images were found. Check 'train_image_folder' and "
                         "'train_annot_folder' in your config file.")

    # Read first image and check the image depth.
    img_first = imagereader.image_read(train_imgs[0]['filename'])
    depth = 1 if _is_greyscale(img_first) else 3

    anchors = _compute_anchors(config, img_first, num_patches, grid_w, grid_h)

    # Expand '~' explicitly: neither os.path.exists nor os.makedirs does it.
    log_path = os.path.expanduser(config['train'].get('log_path') or '~/logs/')
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    # Get overlap patches
    overlap_patches = 0
    if 'overlap_patches' in config['model']:
        overlap_patches = int(config['model']['overlap_patches'])
    elif not len(config['model']['anchors']) > 2:
        overlap_patches = config['model']['anchors'][0]

    ###############################
    #   Construct the model
    ###############################
    yolo_kwargs = dict(architecture=config['model']['architecture'],
                       input_size=config['model']['input_size'],
                       input_depth=depth,
                       labels=config['model']['labels'],
                       max_box_per_image=config['model']['max_box_per_image'],
                       anchors=anchors,
                       backend_weights=config['model'].get('backend_weights'))

    if num_gpus > 1:
        # With several GPUs the template model is created under the CPU
        # device scope and replicated further below.
        import tensorflow as tf
        with tf.device('/cpu:0'):
            yolo = YOLO(**yolo_kwargs)
    else:
        yolo = YOLO(**yolo_kwargs)

    ###############################
    #   Load the pretrained weights (if any)
    ###############################
    if os.path.exists(config['train']['pretrained_weights']):
        print("Loading pre-trained weights in", config['train']['pretrained_weights'])
        yolo.load_weights(config['train']['pretrained_weights'])

    # USE MULTIGPU
    parallel_model = None
    if num_gpus > 1:
        parallel_model = multi_gpu_model(yolo.model, gpus=num_gpus)

    ###############################
    #   Start the training process
    ###############################
    start = time.time()
    yolo.train(train_imgs=train_imgs,
               valid_imgs=valid_imgs,
               train_times=config['train']['train_times'],
               valid_times=config['valid']['valid_times'],
               nb_epoch=config['train']['nb_epoch'],
               learning_rate=config['train']['learning_rate'],
               batch_size=config['train']['batch_size'],
               warmup_epochs=warmup_epochs,
               object_scale=config['train']['object_scale'],
               no_object_scale=config['train']['no_object_scale'],
               coord_scale=config['train']['coord_scale'],
               class_scale=config['train']['class_scale'],
               saved_weights_name=config['train']['saved_weights_name'],
               debug=config['train']['debug'],
               log_path=log_path,
               early_stop_thresh=early_stop,
               num_patches=num_patches,
               warm_restarts=warm_restarts,
               overlap_patches=overlap_patches,
               parallel_model=parallel_model)
    end = time.time()
    print("Time elapsed for training:", (end - start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment