"""Script to facilitate the training process of the TensorFlow Object Detection API."""
import os, sys
import shutil
from argparse import ArgumentParser
from random import randrange
from PIL import Image
import glob
import xml.etree.ElementTree as ET
import pandas as pd
import tensorflow as tf
from google.protobuf import text_format
import time
import subprocess
# If the Object Detection API is not importable yet, compile its protos,
# install it (plus slim), and expose both packages to the child processes.
try:
    import object_detection
    from object_detection.protos import pipeline_pb2
    import nets
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    os.system('cd model/ && protoc object_detection/protos/*.proto --python_out=.')
    os.system('cd model/ && pip install .')
    os.system('cd model/slim && python setup.py install')
    from object_detection.protos import pipeline_pb2
os.environ["PYTHONPATH"] = os.pathsep.join(
    [os.path.abspath('./model'), os.path.abspath('./model/slim')])
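# Note: mutating os.environ["PYTHONPATH"] here only affects the child processes
# spawned below (train.py, generate_tfrecord.py, xml_to_csv.py); the current
# interpreter already imported the freshly installed packages in the try/except.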
__author__ = ""
__date__ = ""
__version__ = ""
__description__ = ""
if __name__ == '__main__':
    argParser = ArgumentParser(description=__description__,
                               epilog='Developed by ' + __author__ + ' in ' + __date__)
    argParser.add_argument('--input-images', dest="input_images", type=str, required=True)
    argParser.add_argument('--input-annotations', dest="input_annotations", type=str, required=True)
    argParser.add_argument('--num-steps', dest="num_steps", type=int, default=40000)
    argParser.add_argument('--pipeline-file', dest="pipeline_file", type=str,
                           default='model/object_detection/samples/configs/faster_rcnn_inception_v2_pets.config')
    # CSV mapping each pipeline .config file to the URL of its pretrained weights;
    # defaults to None so the single-model branch at the bottom is reachable.
    argParser.add_argument('--config-weights-relation-file', dest="config_weights_relation_file",
                           type=str, default=None)
    argParser.add_argument('--continue', dest="continue_training", action='store_true')
    args = argParser.parse_args()
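    # Example invocation (hypothetical paths; assumes this script is saved as train_od.py):
    #   python train_od.py --input-images ./data/images --input-annotations ./data/annotations \
    #       --num-steps 40000 --config-weights-relation-file ./config_weights.csv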
    ANNOTATIONS_PATH = os.path.abspath(args.input_annotations)
    IMAGES_PATH = os.path.abspath(args.input_images)
    OD_PATH = os.path.abspath('model/object_detection/')
    IMAGES_TEST_PATH = os.path.abspath(OD_PATH + '/images/test/')
    IMAGES_TRAIN_PATH = os.path.abspath(OD_PATH + '/images/train/')
    CONF_FILES_PATH = './model/object_detection/samples/configs/'
    WEIGHTS_PATH = './model/object_detection/weights/'
    # Resume a previously started run and exit; the dataset preparation below is skipped.
    if args.continue_training:
        pipeline_file_name = args.pipeline_file.replace('\\', '/').split('/')[-1]
        os.system(f'cd {OD_PATH} && \
            python train.py \
            --logtostderr \
            --train_dir=training/ \
            --pipeline_config_path=training/{pipeline_file_name} --worker_replicas=2 --ps_tasks=1')
        sys.exit(0)
    def clean_folder(folder):
        """Delete `folder` if it exists, then recreate it empty."""
        if os.path.isdir(folder):
            shutil.rmtree(folder)
        os.mkdir(folder)

    ## Clean old train/test sets, training artifacts and inference graph ##
    clean_folder(OD_PATH + '/inference_graph')
    clean_folder(OD_PATH + '/training')
    clean_folder(f'{OD_PATH}/images')
    os.mkdir(IMAGES_TEST_PATH)
    os.mkdir(IMAGES_TRAIN_PATH)
    if not os.path.isdir(WEIGHTS_PATH):
        os.mkdir(WEIGHTS_PATH)
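    # Resulting layout under model/object_detection/:
    #   images/{train,test}/   annotated .jpg + .xml pairs
    #   training/              label map, patched configs and checkpoints
    #   inference_graph/       reserved for the exported frozen graph
    #   weights/               downloaded pretrained checkpoints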
    ## Split ~80/20 into train/test and convert all images to jpg ##
    for annot_xml in glob.glob(ANNOTATIONS_PATH + '/*.xml'):
        image_name = annot_xml.replace('\\', '/').split('/')[-1].split('.')[0]
        try:
            im = Image.open(f'{IMAGES_PATH}/{image_name}.bmp')
        except FileNotFoundError:
            try:
                im = Image.open(f'{IMAGES_PATH}/{image_name}.jpg')
            except FileNotFoundError:
                print('FileNotFoundError:', image_name)
                continue
        if randrange(10) > 7:  # 8 or 9, i.e. ~20% of the samples become the test set
            shutil.copy(annot_xml, IMAGES_TEST_PATH)
            im.save(f'{IMAGES_TEST_PATH}/{image_name}.jpg')
        else:
            shutil.copy(annot_xml, IMAGES_TRAIN_PATH)
            im.save(f'{IMAGES_TRAIN_PATH}/{image_name}.jpg')
    ## Rewrite <filename>/<path> inside each annotation to point at the copied .jpg ##
    def fix_filepath(annot_xml):
        tree = ET.parse(annot_xml)
        image_name = annot_xml.replace('\\', '/').split('/')[-1].split('.')[0] + '.jpg'
        image_path = os.path.join(os.path.dirname(os.path.abspath(annot_xml)), image_name)
        tree.find('filename').text = image_name
        tree.find('path').text = image_path
        tree.write(annot_xml)

    for annot_xml in glob.glob(IMAGES_TEST_PATH + '/*.xml'):
        fix_filepath(annot_xml)
    for annot_xml in glob.glob(IMAGES_TRAIN_PATH + '/*.xml'):
        fix_filepath(annot_xml)

    os.system(f'cd {OD_PATH} && python ./xml_to_csv.py')
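    # xml_to_csv.py is expected to emit train_labels.csv/test_labels.csv with the usual
    # columns of the Pascal-VOC-to-CSV converters: filename, width, height, class,
    # xmin, ymin, xmax, ymax (only the 'class' column is read below).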
    ## Write the label map (labelmap.pbtxt) from the classes found in the csv ##
    file = pd.read_csv(f'{OD_PATH}/images/train_labels.csv')
    categories = file['class'].unique()
    class_map = {}
    with open(f'{OD_PATH}/training/labelmap.pbtxt', 'w') as f:
        for ID, name in enumerate(categories):
            f.write('item {\n')
            f.write(f'  id: {ID + 1}\n')
            f.write(f"  name: '{name}'\n")
            f.write('}\n\n')
            class_map[name] = ID + 1
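    # The generated labelmap.pbtxt looks like this (class names are dataset-dependent):
    #   item {
    #     id: 1
    #     name: 'cat'
    #   }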
os.system(f"cd {OD_PATH} && \
python generate_tfrecord.py \
--csv_input=images/train_labels.csv \
--image_dir=images/train \
--output_path=train.record \
--classes {' '.join(file['class'].unique())}")
os.system(f"cd {OD_PATH} && \
python generate_tfrecord.py \
--csv_input=images/test_labels.csv \
--image_dir=images/test \
--output_path=test.record \
--classes {' '.join(file['class'].unique())}")
    if args.config_weights_relation_file:
        # CSV with no header row: column 0 is the pipeline .config file name,
        # column 1 is the download URL of the matching pretrained weights.
        config_relations_df = pd.read_csv(args.config_weights_relation_file, header=None)
        all_logs_dir = os.path.abspath('./train-logs')
        if not os.path.isdir(all_logs_dir):
            os.mkdir(all_logs_dir)
        current_logs_dir_name = time.strftime("%Y%m%d-%H%M%S")
        current_logs_dir = os.path.abspath(f'{all_logs_dir}/{current_logs_dir_name}')
        os.mkdir(current_logs_dir)
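        # Example relation file (hypothetical rows; any model-zoo .tar.gz URL works):
        #   faster_rcnn_inception_v2_pets.config,http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
        #   ssd_mobilenet_v1_pets.config,http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz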
        for idx, row in config_relations_df.iterrows():
            pipeline_file = CONF_FILES_PATH + '/' + row[0]
            pipeline_file_name = pipeline_file.split('/')[-1].split('.')[0]
            url_download_weight = row[1]
            # e.g. 'faster_rcnn_inception_v2_coco_2018_01_28' from '...2018_01_28.tar.gz'
            weight_file_name = '.'.join(row[1].split('/')[-1].split('.')[:-2])
            weight_file = os.path.join(WEIGHTS_PATH, row[1].split('/')[-1])
            weight_path = os.path.join(WEIGHTS_PATH, weight_file_name)
            # Download and extract the pretrained weights if not already present
            if not os.path.isdir(weight_path):
                if not os.path.isfile(weight_file):
                    subprocess.run(f'wget -O {weight_file} {url_download_weight}', shell=True)
                subprocess.run(f'tar -xzf {weight_file} -C {WEIGHTS_PATH} && rm {weight_file}', shell=True)
            train_path = os.path.abspath(f'{OD_PATH}/training/{weight_file_name}-{pipeline_file_name}')
            if os.path.isdir(train_path):
                shutil.rmtree(train_path)
            os.mkdir(train_path)
            ## Load the sample pipeline config and patch it for this dataset ##
            pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
            with tf.gfile.GFile(pipeline_file, "r") as f:
                proto_str = f.read()
                text_format.Merge(proto_str, pipeline_config)
            # The num_classes field lives under model.ssd or model.faster_rcnn,
            # depending on the meta-architecture of the config.
            if 'ssd' in pipeline_file:
                pipeline_config.model.ssd.num_classes = len(categories)
            else:
                pipeline_config.model.faster_rcnn.num_classes = len(categories)
            pipeline_config.train_config.fine_tune_checkpoint = os.path.abspath(f'{weight_path}/model.ckpt')
            pipeline_config.train_input_reader.tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/train.record')]
            pipeline_config.train_input_reader.label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
            pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/test.record')]
            pipeline_config.eval_input_reader[0].label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
            pipeline_config.train_config.num_steps = args.num_steps
            config_text = text_format.MessageToString(pipeline_config)
            pipeline_file_name = pipeline_file.replace('\\', '/').split('/')[-1]  # now includes the extension
            pipeline_file_path = f'{train_path}/{pipeline_file_name}'
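            # After patching, the written .config contains (roughly):
            #   train_config {
            #     fine_tune_checkpoint: "<abs path>/weights/<model>/model.ckpt"
            #     num_steps: <args.num_steps>
            #   }
            #   train_input_reader {
            #     label_map_path: "<abs path>/training/labelmap.pbtxt"
            #     tf_record_input_reader { input_path: "<abs path>/train.record" }
            #   }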
            def write_pipeline_conf():
                with tf.gfile.Open(pipeline_file_path, "wb") as f:
                    f.write(config_text)
            write_pipeline_conf()

            current_train_dir = os.path.abspath(f'{current_logs_dir}/{weight_file_name}-{pipeline_file_name}')
            os.mkdir(current_train_dir)
            train_logs_file = os.path.abspath(f"{current_train_dir}/logs.txt")
            print(f"Training model {weight_file_name} using configs {pipeline_file_name}")

            def start_training(f):
                return subprocess.run(["python", f"{OD_PATH}/train.py",
                                       "--train_dir", train_path,
                                       "--pipeline_config_path", pipeline_file_path],
                                      stdout=f, stderr=f)

            # Different API versions expect different checkpoint-loading fields, so on
            # failure retry with the alternatives; the config file must be rewritten
            # before each retry, otherwise the changes would never reach train.py.
            with open(train_logs_file, "wb") as f:
                proc_result = start_training(f)
                if proc_result.returncode != 0:
                    pipeline_config.train_config.from_detection_checkpoint = False
                    config_text = text_format.MessageToString(pipeline_config)
                    write_pipeline_conf()
                    proc_result = start_training(f)
                if proc_result.returncode != 0:
                    pipeline_config.train_config.from_detection_checkpoint = True
                    config_text = text_format.MessageToString(pipeline_config)
                    write_pipeline_conf()
                    proc_result = start_training(f)
                if proc_result.returncode != 0:
                    pipeline_config.train_config.fine_tune_checkpoint_type = "detection"
                    config_text = text_format.MessageToString(pipeline_config)
                    write_pipeline_conf()
                    proc_result = start_training(f)
            with open(f"{current_logs_dir}/general-train-results.txt", "a") as f:
                if proc_result.returncode != 0:
                    f.write(f"problem during training of model {weight_file_name} using configs {pipeline_file_name}\n")
                else:
                    f.write(f"successfully trained model {weight_file_name} using configs {pipeline_file_name}\n")
    else:
        ## Single-model path: patch the pipeline given by --pipeline-file and train it ##
        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        with tf.gfile.GFile(args.pipeline_file, "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, pipeline_config)
        # This branch assumes a faster_rcnn config (the default pipeline file)
        pipeline_config.model.faster_rcnn.num_classes = len(categories)
        pipeline_config.train_config.fine_tune_checkpoint = os.path.abspath(f'{OD_PATH}/faster_rcnn_inception_v2_coco_2018_01_28/model.ckpt')
        pipeline_config.train_input_reader.tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/train.record')]
        pipeline_config.train_input_reader.label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
        pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/test.record')]
        pipeline_config.eval_input_reader[0].label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
        pipeline_config.train_config.num_steps = args.num_steps
        config_text = text_format.MessageToString(pipeline_config)
        pipeline_file_name = args.pipeline_file.replace('\\', '/').split('/')[-1]
        with tf.gfile.Open(f"{OD_PATH}/training/{pipeline_file_name}", "wb") as f:
            f.write(config_text)
        os.system(f'cd {OD_PATH} && \
            python train.py \
            --logtostderr \
            --train_dir=training/ \
            --pipeline_config_path=training/{pipeline_file_name} --worker_replicas=2 --ps_tasks=1')
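        # inference_graph/ (cleaned above) is reserved for a later manual export, e.g.:
        #   python export_inference_graph.py --input_type image_tensor \
        #       --pipeline_config_path training/<pipeline>.config \
        #       --trained_checkpoint_prefix training/model.ckpt-<step> \
        #       --output_directory inference_graph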