TensorFlow Object Detection Model Training

How to train your own object detection models using the TensorFlow Object Detection API (2020 Update)

This started as a summary of this nice tutorial, but has since become its own thing.

Prerequisites

  1. Choose a TensorFlow installation. TensorFlow 1 and 2 have different neural networks available, so check here and here to make your choice.

    • Tip: if you opt for one of the TF1 models, note that the Object Detection API is only officially compatible with TF 1.15.0, which works only with CUDA 10.0 (unless you compile it from source). From personal experience, TF 1.12 and earlier no longer work with the Object Detection API.
  2. Install TensorFlow.

  3. Download the TensorFlow models repository.
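Before moving on, it can help to confirm which TensorFlow version your Python environment actually picks up, since that decides which model zoo and scripts apply. The sketch below is a minimal check. The helper name tf_api_flavor is hypothetical, and its cut-offs only mirror the tip above; they are not an official compatibility table.

```python
def tf_api_flavor(version):
    """Classify a TensorFlow version string such as "1.15.0" or "2.4.1".

    The cut-offs mirror the notes in this tutorial, not an official table.
    """
    major, minor = (int(part) for part in version.split(".")[:2])
    if major >= 2:
        return "TF2"
    if major == 1 and minor == 15:
        return "TF1"
    if major == 1 and minor in (13, 14):
        return "TF1 (not officially supported)"
    return "unsupported"  # TF 1.12 and earlier no longer work with the API


if __name__ == "__main__":
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow is not installed in this Python environment")
    else:
        print(tf.__version__, "->", tf_api_flavor(tf.__version__))
```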

Annotating images and serializing the dataset

For these steps, I recommend a collection of scripts I made, which are available in this repository. All of the scripts mentioned in this section receive arguments from the command line and print help messages with the -h/--help flags. If you need more details, also check the README of the repository they come from.

  1. Install labelImg. It is a Python package, so you can install it via pip, but I recommend the version from GitHub.

  2. Annotate your dataset using labelImg. Each image you annotate will have its annotations saved to an individual XML file with the name of the original image file and the .xml extension.

  3. Use this script to convert the XML files generated by labelImg into a single CSV file.

    • Optional: Use this script to split the CSV file into two, one with training examples and one with evaluation examples. Let's call them train.csv and eval.csv. Images are selected randomly, and there are options to stratify examples by class, making sure that objects from all classes are present in both datasets. The usual proportion is 75% to 80% of the annotated objects for training, with the rest going to the evaluation dataset.
  4. Create a "label map" for your classes. You can check some examples to understand what they look like. You can also generate one from your original CSV file with this script.

  5. Use this script to convert each of your CSV files into a TFRecord file (e.g. train.record and eval.record), TensorFlow's preferred serialized data format. You'll need to point to the directory where the image files are stored and to the label map generated in the previous step.

    • Tip: if you notice mistakes during the creation of these files, you can check their contents and compare them with the ones in these examples.
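If you want to see what the XML-to-CSV conversion in step 3 involves, or need to adapt it, here is a minimal sketch that uses only the standard library. It assumes labelImg's default Pascal VOC output (one object element per box, each with a bndbox child) and a hypothetical column layout (filename, width, height, class, xmin, ymin, xmax, ymax). Check which columns your CSV-to-TFRecord script expects before relying on it. The function names are illustrative and are not the ones used in the repository's scripts.

```python
import csv
import glob
import os
import xml.etree.ElementTree as ET

# Hypothetical column layout; match it to what your TFRecord script reads.
COLUMNS = ["filename", "width", "height", "class", "xmin", "ymin", "xmax", "ymax"]


def parse_voc_root(root):
    """Return one row (dict) per <object> in a Pascal VOC annotation element."""
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    rows = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        row = {"filename": filename, "width": width, "height": height,
               "class": obj.findtext("name")}
        for key in ("xmin", "ymin", "xmax", "ymax"):
            # Some tools write float coordinates, so parse through float() first.
            row[key] = int(float(box.findtext(key)))
        rows.append(row)
    return rows


def xml_dir_to_csv(xml_dir, csv_path):
    """Collect every *.xml file in xml_dir into one CSV; return the row count."""
    rows = []
    for path in sorted(glob.glob(os.path.join(xml_dir, "*.xml"))):
        rows.extend(parse_voc_root(ET.parse(path).getroot()))
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=COLUMNS)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

Since the file name comes from the filename element inside each XML, annotations whose filename was altered by another tool will carry that altered value into the CSV.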
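The optional train/eval split can be sketched in the same spirit. This version is an assumption, not the repository's algorithm. It splits by image rather than by object, so the same picture never lands in both sets, and it reshuffles until every class appears in both sets. Because whole images are moved, the object proportion only approximates train_fraction. The function name is hypothetical.

```python
import random
from collections import defaultdict


def split_by_image(rows, train_fraction=0.8, seed=0, max_tries=100):
    """Split annotation rows (dicts with "filename" and "class") into train/eval lists.

    All boxes of an image go to the same set; shuffles are retried until every
    class appears in both sets.
    """
    by_image = defaultdict(list)
    for row in rows:
        by_image[row["filename"]].append(row)
    filenames = sorted(by_image)
    all_classes = {row["class"] for row in rows}
    rng = random.Random(seed)  # seeded so the split is reproducible
    for _ in range(max_tries):
        rng.shuffle(filenames)
        cut = round(len(filenames) * train_fraction)
        train = [row for name in filenames[:cut] for row in by_image[name]]
        evaluation = [row for name in filenames[cut:] for row in by_image[name]]
        if ({row["class"] for row in train} == all_classes
                and {row["class"] for row in evaluation} == all_classes):
            return train, evaluation
    raise ValueError("no split found with every class in both sets; "
                     "some class may appear in a single image")
```

The rows can come from csv.DictReader over the main CSV, and each returned list can be written back out with csv.DictWriter as train.csv and eval.csv.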
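A label map is a small text file in protobuf text format, with one item per class. The minimal generator below assumes the class names are taken from the CSV's class column and sorted alphabetically, which is an arbitrary choice. IDs start at 1 because the Object Detection API reserves 0 for the background class. Whatever mapping you pick, the TFRecords must use the same IDs. The function names are hypothetical.

```python
import csv


def make_label_map(class_names):
    """Render a label map (protobuf text format) for the given class names.

    Names are de-duplicated and sorted so the IDs are reproducible; IDs start
    at 1 because 0 is reserved for the background class. Names containing a
    single quote are not escaped in this sketch.
    """
    items = []
    for class_id, name in enumerate(sorted(set(class_names)), start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return "\n".join(items)


def label_map_from_csv(csv_path, column="class"):
    """Read the class column of an annotations CSV and build its label map."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        return make_label_map(row[column] for row in csv.DictReader(f))
```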
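Step 5 comes down to building one tf.train.Example per image. The sketch below shows the idea under a few assumptions. The feature keys are the ones the Object Detection API's TF Example decoder reads, to the best of my knowledge. Boxes are normalized to [0, 1] by the image size. Class IDs come from a label map dictionary such as {"cat": 1}. The helper names are hypothetical. TensorFlow is only imported inside to_tf_example, so the pure-Python part runs without it.

```python
INT_KEYS = {"image/height", "image/width", "image/object/class/label"}
FLOAT_KEYS = {"image/object/bbox/xmin", "image/object/bbox/xmax",
              "image/object/bbox/ymin", "image/object/bbox/ymax"}


def example_features(filename, width, height, objects, label_map):
    """Compute feature values for one image.

    objects: list of (class_name, xmin, ymin, xmax, ymax) in pixels (numbers or
    numeric strings, as read from a CSV). label_map: dict of class name -> ID.
    """
    w, h = float(width), float(height)
    name = filename.encode("utf8")
    return {
        "image/height": [int(h)],
        "image/width": [int(w)],
        "image/filename": [name],
        "image/source_id": [name],
        # Box coordinates are normalized to [0, 1] by the image size.
        "image/object/bbox/xmin": [float(o[1]) / w for o in objects],
        "image/object/bbox/ymin": [float(o[2]) / h for o in objects],
        "image/object/bbox/xmax": [float(o[3]) / w for o in objects],
        "image/object/bbox/ymax": [float(o[4]) / h for o in objects],
        "image/object/class/text": [o[0].encode("utf8") for o in objects],
        "image/object/class/label": [label_map[o[0]] for o in objects],
    }


def to_tf_example(features, encoded_image, image_format=b"jpeg"):
    """Wrap the feature values plus the raw image bytes in a tf.train.Example."""
    import tensorflow as tf  # lazy import: example_features works without TensorFlow

    def make_feature(key, values):
        if key in INT_KEYS:
            return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
        if key in FLOAT_KEYS:
            return tf.train.Feature(float_list=tf.train.FloatList(value=values))
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    feature = {key: make_feature(key, values) for key, values in features.items()}
    feature["image/encoded"] = make_feature("image/encoded", [encoded_image])
    feature["image/format"] = make_feature("image/format", [image_format])
    return tf.train.Example(features=tf.train.Features(feature=feature))
```

The examples would then be written with a tf.io.TFRecordWriter. Call writer.write(example.SerializeToString()) once per image, passing the raw bytes of the JPEG or PNG file as encoded_image, and close the writer at the end.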

Choosing a neural network and preparing the training pipeline

  1. Download the neural network model of your choice from either the Detection Model Zoo [TF1][TF2] or from the models trained for classification available here and here. This is the step where your choice of TensorFlow version makes a difference. In my experience, many of the classification models work with TF 1.15, but I don't know whether they work with TF 2.

  2. Provide a training pipeline, which is a file with the .config extension that describes the training procedure. The models provided in the Detection Zoo come with their own pipelines inside their .tar.gz file, but the classification models do not. In this situation, your options are to:

    • download one that is close enough from here (I have successfully done this to train the classification MobileNets V1, V2 and V3 for detection).
    • create your own, by following this tutorial.

    The pipeline config file has some fields that must be adjusted before training is started. The first thing you'll definitely want to keep an eye on is the num_classes attribute, which you'll need to change to the number of classes in your personal dataset.

    Other important fields are the ones containing the PATH_TO_BE_CONFIGURED string. In these fields, you'll need to point to the files they ask for, such as the label map, the training and evaluation TFRecords and the neural network checkpoint. The checkpoint is stored in files with extensions like .ckpt.index and .ckpt.data-####-of-####, which also come inside the .tar.gz file. Point the config at their common prefix (e.g. model.ckpt), not at the .data file itself.

    If you are using a model from the Detection Zoo, set the fine_tune_checkpoint_type field to "detection"; otherwise, set it to "classification".

    Additional parameters affect both how much RAM the training process consumes and the quality of the training. Settings such as the batch size, or how many batches TensorFlow prefetches and keeps in memory, can considerably increase the amount of RAM required. I won't go over them here, since tuning them involves a lot of trial and error.
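The manual edits described above can also be scripted. The sketch below treats the config as plain text and uses regular expressions. It only works when each field sits on one line in the usual name: value form, as in the zoo configs; it is not a real protobuf parser, and the API's object_detection.utils.config_util module is the more robust route. It also assumes that the first input_path belongs to train_input_reader and the second to eval_input_reader, and that checkpoint is the checkpoint prefix. The function name is hypothetical.

```python
import re


def configure_pipeline(text, num_classes, checkpoint, label_map, train_record,
                       eval_record, checkpoint_type="detection"):
    """Fill in the fields of a pipeline .config that usually need editing.

    Regex-based sketch: assumes one field per line and exactly two input_path
    fields, the training one first and the evaluation one second.
    """
    # Replacement functions (instead of strings) keep backslashes in paths literal.
    text = re.sub(r"num_classes:\s*\d+", lambda m: "num_classes: %d" % num_classes, text)
    text = re.sub(r'fine_tune_checkpoint:\s*"[^"]*"',
                  lambda m: 'fine_tune_checkpoint: "%s"' % checkpoint, text)
    text = re.sub(r'label_map_path:\s*"[^"]*"',
                  lambda m: 'label_map_path: "%s"' % label_map, text)

    if len(re.findall(r'input_path:\s*"[^"]*"', text)) != 2:
        raise ValueError("expected exactly two input_path fields (train, then eval)")
    records = iter([train_record, eval_record])
    text = re.sub(r'input_path:\s*"[^"]*"',
                  lambda m: 'input_path: "%s"' % next(records), text)

    if "fine_tune_checkpoint_type:" in text:
        text = re.sub(r'fine_tune_checkpoint_type:\s*"[^"]*"',
                      lambda m: 'fine_tune_checkpoint_type: "%s"' % checkpoint_type, text)
    else:
        # Add the field right below fine_tune_checkpoint, with the same indentation.
        text = re.sub(r"^([ \t]*)fine_tune_checkpoint:[^\n]*",
                      lambda m: '%s\n%sfine_tune_checkpoint_type: "%s"'
                      % (m.group(0), m.group(1), checkpoint_type),
                      text, flags=re.M)
    return text
```

Typical use is to read the .config that came in the .tar.gz, pass its text through configure_pipeline, and save the result as a new file for training.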

Training the network

  1. Train the model. To do it locally, follow the steps available here: [TF1][TF2]. Optionally, to check training progress, start TensorBoard with its --logdir pointing to the --model_dir passed to object_detection/model_main.py (object_detection/model_main_tf2.py in TF2).

  2. Export the network, like this.

  3. Use the exported .pb in your object detector.
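For reference, the training and export steps can be launched from Python. This sketch only assembles the command lines. It assumes it is run from the models/research directory, and it uses the script and flag names from the object_detection package's instructions: model_main.py and export_inference_graph.py for TF1, and model_main_tf2.py and exporter_main_v2.py for TF2. Double-check them against the version of the repository you cloned. The helper names are hypothetical.

```python
import sys


def train_command(pipeline_config, model_dir, tf2=True):
    """Command line for local training (run from models/research)."""
    script = ("object_detection/model_main_tf2.py" if tf2
              else "object_detection/model_main.py")
    return [sys.executable, script,
            "--pipeline_config_path=" + pipeline_config,
            "--model_dir=" + model_dir,
            "--alsologtostderr"]


def export_command(pipeline_config, model_dir, output_dir, tf2=True,
                   checkpoint_prefix=None):
    """Command line for exporting a trained model for inference."""
    if tf2:
        # The TF2 exporter restores the latest checkpoint found in model_dir.
        return [sys.executable, "object_detection/exporter_main_v2.py",
                "--input_type=image_tensor",
                "--pipeline_config_path=" + pipeline_config,
                "--trained_checkpoint_dir=" + model_dir,
                "--output_directory=" + output_dir]
    if checkpoint_prefix is None:
        raise ValueError("TF1 export needs a checkpoint prefix such as model.ckpt-<step>")
    return [sys.executable, "object_detection/export_inference_graph.py",
            "--input_type=image_tensor",
            "--pipeline_config_path=" + pipeline_config,
            "--trained_checkpoint_prefix=" + checkpoint_prefix,
            "--output_directory=" + output_dir]
```

Each list can be passed to subprocess.run(..., check=True). TensorBoard's --logdir should point at the same model_dir used for training.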
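Once the model is exported, using it mostly means feeding in an image tensor and reading back boxes, scores and classes. The sketch below assumes the TF2 exporter's SavedModel output (the saved_model directory that contains saved_model.pb). It also assumes the output names used in the API's inference examples (num_detections, detection_boxes, detection_scores, detection_classes), with boxes given as normalized [ymin, xmin, ymax, xmax]. A TF1 frozen .pb is loaded differently, through a GraphDef, and is not covered here. The function names and the 0.5 score threshold are illustrative.

```python
def filter_detections(boxes, scores, classes, image_width, image_height, min_score=0.5):
    """Keep detections scoring at least min_score and convert their boxes to pixels.

    boxes are normalized [ymin, xmin, ymax, xmax]; the result uses
    (xmin, ymin, xmax, ymax) pixel coordinates.
    """
    results = []
    for (ymin, xmin, ymax, xmax), score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        results.append({
            "class": int(cls),
            "score": float(score),
            "box": (int(round(xmin * image_width)), int(round(ymin * image_height)),
                    int(round(xmax * image_width)), int(round(ymax * image_height))),
        })
    return results


def detect(saved_model_dir, image):
    """Run an exported TF2 SavedModel on one HxWx3 uint8 NumPy image."""
    import numpy as np  # lazy imports: filter_detections does not need them
    import tensorflow as tf

    detect_fn = tf.saved_model.load(saved_model_dir)
    outputs = detect_fn(tf.convert_to_tensor(image[np.newaxis, ...], dtype=tf.uint8))
    n = int(outputs["num_detections"][0])
    height, width = image.shape[:2]
    return filter_detections(outputs["detection_boxes"][0][:n].numpy(),
                             outputs["detection_scores"][0][:n].numpy(),
                             outputs["detection_classes"][0][:n].numpy(),
                             width, height)
```

The returned class values are the label map IDs, so the label map from the dataset step can translate them back into names.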

Final Tips

In the data augmentation section of the training pipeline, you can add or remove options to try to improve training. Some of the available options are listed here.
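As an illustration, the sketch below adds entries to the data_augmentation_options of train_config, treating the pipeline config as plain text. The option names in the usage comment (random_horizontal_flip, random_adjust_brightness) are examples of preprocessor options the API defines, used here with their default parameters. The helper names are hypothetical, and the regex assumes train_config { starts its own line.

```python
import re


def augmentation_block(option_names, indent="  "):
    """Render one data_augmentation_options entry per option, with default parameters."""
    lines = []
    for name in option_names:
        lines += [indent + "data_augmentation_options {",
                  indent * 2 + name + " {",
                  indent * 2 + "}",
                  indent + "}"]
    return "\n".join(lines)


def add_augmentations(config_text, option_names):
    """Insert the options at the top of the train_config section."""
    block = augmentation_block(option_names)
    new_text, count = re.subn(r"^train_config\s*\{[^\n]*$",
                              lambda m: m.group(0) + "\n" + block,
                              config_text, count=1, flags=re.M)
    if count == 0:
        raise ValueError("no train_config section found")
    return new_text


# Example: add_augmentations(text, ["random_horizontal_flip", "random_adjust_brightness"])
```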

@k1732036

k1732036 commented Mar 28, 2018

It became easy for me to create my own object detector after reading your article.
Thank you!

@douglasrizzo

Owner Author

douglasrizzo commented Jun 1, 2018

Hey, I'm glad you liked it. If you have any suggestions, leave a comment or maybe edit the article. Also, if anyone is interested, I just created a Python package that abstracts some of the more gruesome parts of using TensorFlow Object Detection API.

@freedylam

freedylam commented Jun 7, 2018

Hi, douglasrizzo.

Would you mind listing how you organized the folders for the images, XML, CSV, TFRecord and config files, and how you issued the commands?
I think I'm missing some part: when I run training, train.py always gets Killed.

I'm already using the CPU version of TF and have 4 GB of memory on my Ubuntu machine.

Thanks,
Freedy

@devajith

devajith commented Sep 18, 2018

Can we split the CSV using a script?

@douglasrizzo

Owner Author

douglasrizzo commented Oct 10, 2018

@devajith you can, but it's a little tricky. You have to make sure you have examples of all classes in both the training and testing CSVs.

@Tsarpf

Tsarpf commented Jan 8, 2019

The link to the package is dead. But thanks a lot for the cleaned up tutorial!

@douglasrizzo

Owner Author

douglasrizzo commented Mar 18, 2019

Hey @Tsarpf, I just fixed the link to the package.

@douglasrizzo

Owner Author

douglasrizzo commented Apr 23, 2019

Hi everyone, I've just updated the tutorial, mentioning a few scripts I created to help separate the main CSV into train and evaluation datasets, as well as generating the label map from the main CSV file. Since I've been the only one who tested these scripts, any issues or PRs in their repo are welcome.

@Rakshasv18

Rakshasv18 commented May 2, 2019

@douglasrizzo hey! I'm getting the error "No module named pandas", even though it's installed!

@douglasrizzo

Owner Author

douglasrizzo commented May 2, 2019

@Rakshasv18 make sure you're using the Python version that has pandas installed. You may have Python 2 and 3 installed, or multiple versions of Python 3.

@Friso1987

Friso1987 commented Jul 12, 2019

When I use the XML_to_CSV script, my filenames are not included in the new CSV file; they are filled in as two dashes, "--". How do I include my filenames in the conversion?

@douglasrizzo

Owner Author

douglasrizzo commented Jul 12, 2019

@Friso1987 could you please check if your XML files contain the file names in them? Also, in which OS are you using the generate_csv.py script? I suspect that it won't work properly on Windows, since I've tested it only on Linux. You may submit an issue in the repo if you like, so I can take a look at it.

@Friso1987

Friso1987 commented Jul 13, 2019

Found the mistake (stupid me). I got my XML files from an external source. When opening the XML files, I found out the filename was changed to "--". Probably for privacy reasons. I simply fixed it by saving the .xml files again in the LabelImg application, so it had a correct filename.

@MasudHaider

MasudHaider commented Sep 24, 2019

Can this tutorial be run on Google Colaboratory?

@douglasrizzo

Owner Author

douglasrizzo commented Oct 10, 2019

@MasudHaider I imagine so. As long as you are able to clone and install the TF Object Detection API in your Colab, the other steps are simple.

@Idakwo

Idakwo commented Oct 28, 2019

@douglasrizzo Can you please share a file to recreate the environment you used, like a requirements.txt or .yml? I keep getting one error after another every time I fix a dependency error. There are several GitHub commits to TF's Object Detection repo corresponding to different TF versions, and I see changes in the repo for TF 2.0. It'd be great to know which commit worked for you and the corresponding setup. Thanks

@douglasrizzo

Owner Author

douglasrizzo commented Oct 29, 2019

Hi @Idakwo. I haven't updated this tutorial in a while. It works with tensorflow==1.13, but I can't pinpoint the version of the tensorflow/models repo that is compatible with TF 1.13. I also know that object detection does not work with TF 2.0 as of yet, so maybe that's one source of your errors.

@marcotacuri

marcotacuri commented Mar 24, 2020

When I run the CSV-to-TFRecords conversion, the following message appears: module 'tensorflow' has no attribute 'app', on the line flags = tf.app.flags. How do I fix this?

@padhulp

padhulp commented Jul 13, 2020

Your scripts helped me create my TF records at last! Thank you so much

@douglasrizzo

Owner Author

douglasrizzo commented Jul 13, 2020

@padhulp glad they helped!
