
@masahi
Created February 23, 2023 21:26

Exporting TF2 detection models to a TVM friendly format

Problem

Models available in https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md are not suitable for TVM ingestion, because NMS is unrolled per class, creating an unnecessarily complicated model.

Solution

We can use the model export script https://github.com/tensorflow/models/blob/master/research/object_detection/exporter_main_v2.py together with the use_combined_nms: true flag in a model config to re-export a model into a compact format that can be imported into Relay easily. This makes the model use the code path at https://github.com/tensorflow/models/blob/238922e98dd0e8254b5c0921b241a1f5a151782f/research/object_detection/core/post_processing.py#L1003 in its post-processing module.

However, there are a few complications that require changes to the model zoo source code.

My fork https://github.com/masahi/models/tree/export-for-tvm contains the necessary changes.
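For intuition, here is a minimal NumPy sketch (illustrative, not the TF implementation) of what combined NMS computes: class-wise suppression over all classes in one pass, with statically shaped, padded outputs, analogous to tf.image.combined_non_max_suppression applied to a single image. The function names here are made up for the sketch:

```python
import numpy as np

def iou(a, b):
    """Intersection-over-union of two [y1, x1, y2, x2] boxes."""
    y1, x1 = max(a[0], b[0]), max(a[1], b[1])
    y2, x2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, y2 - y1) * max(0.0, x2 - x1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0

def combined_nms(boxes, scores, max_total_size, iou_threshold, score_threshold):
    """Class-wise NMS over all classes at once, returning fixed-size
    (padded) outputs. boxes: (N, 4); scores: (N, C) per-class scores."""
    n, c = scores.shape
    # Flatten (box, class) pairs above the score threshold, best first.
    cand = sorted(((scores[i, k], i, k) for i in range(n) for k in range(c)
                   if scores[i, k] > score_threshold), reverse=True)
    kept = []
    for s, i, k in cand:
        if len(kept) == max_total_size:
            break
        # Suppress only against already-kept boxes of the same class.
        if all(kj != k or iou(boxes[i], boxes[j]) <= iou_threshold
               for _, j, kj in kept):
            kept.append((s, i, k))
    # Pad to the static output shape, as the combined op does.
    out_boxes = np.zeros((max_total_size, 4))
    out_scores = np.zeros(max_total_size)
    out_classes = np.full(max_total_size, -1)
    for m, (s, i, k) in enumerate(kept):
        out_boxes[m], out_scores[m], out_classes[m] = boxes[i], s, k
    return out_boxes, out_scores, out_classes, len(kept)
```

Because every output has a static shape, the whole post-processing step stays a single fused op in the exported graph instead of a per-class loop.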

Steps

  1. Apply two changes to the model zoo source code to make it ready for TVM-friendly export

    1. Option a: Use my fork above

    2. Option b: Apply the change manually

      1. Merge Trevor's PR tensorflow/models#9707
      2. Replace the code at https://github.com/tensorflow/models/blob/7f0ee4cb1f10d4ada340cc5bfe2b99d0d690b219/research/object_detection/exporter_lib_v2.py#L106-L112 with the following:
      images, true_shapes = _decode_and_preprocess(batch_input[0])
      return tf.expand_dims(images, 0), tf.expand_dims(true_shapes, 0)
  2. Install the model package by following https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#python-package-installation

  3. Download a model checkpoint from https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md. After unpacking, we should have a directory like ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8

  4. Inside the model directory, there is a file named pipeline.config, containing a post_processing section like:

    post_processing {
          batch_non_max_suppression {
            score_threshold: 9.99999993922529e-09
            iou_threshold: 0.6000000238418579
            max_detections_per_class: 100
            max_total_detections: 100
            use_static_shapes: false
          }
          score_converter: SIGMOID
        }

    Add use_combined_nms: true to the batch_non_max_suppression section, so that it looks like:

    post_processing {
          batch_non_max_suppression {
            use_combined_nms: true
            score_threshold: 9.99999993922529e-09
            iou_threshold: 0.6000000238418579
            max_detections_per_class: 100
            max_total_detections: 100
            use_static_shapes: false
          }
          score_converter: SIGMOID
        }
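    If you prefer to script the edit to pipeline.config, here is a small sketch. The regex-based helper below is illustrative and assumes the config contains a single batch_non_max_suppression block:

```python
import re

def enable_combined_nms(config_text):
    """Insert 'use_combined_nms: true' right after the opening of the
    batch_non_max_suppression block, unless it is already present."""
    if "use_combined_nms" in config_text:
        return config_text
    return re.sub(r"(batch_non_max_suppression\s*\{)",
                  r"\1\n        use_combined_nms: true",
                  config_text, count=1)
```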
  5. Under the models/research directory, create a file (say, export.sh) with the following contents:

    INPUT_TYPE=image_tensor
    PIPELINE_CONFIG_PATH=/home/masa/ssd_mobilenet_v2_320x320_coco17_tpu-8/pipeline.config
    TRAINED_CKPT_PREFIX=/home/masa/ssd_mobilenet_v2_320x320_coco17_tpu-8/checkpoint
    EXPORT_DIR=/home/masa/ssd_mobilenet_v2_320x320_coco17_tpu-8/export
    python object_detection/exporter_main_v2.py \
        --input_type=${INPUT_TYPE} \
        --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
        --trained_checkpoint_dir=${TRAINED_CKPT_PREFIX} \
        --output_directory=${EXPORT_DIR}
  6. Run export.sh. EXPORT_DIR should now contain a saved_model directory, which is the input to the tf2onnx tool.

  7. Go to EXPORT_DIR and run (IMPORTANT: Use opset 12)

    python -m tf2onnx.convert --saved-model saved_model --opset 12 --output model.onnx
  8. View the exported ONNX model in Netron to verify that NMS is not unrolled.
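To check this programmatically instead of (or in addition to) Netron, one can count NonMaxSuppression nodes in the graph: a per-class unrolled export contains many, while the combined export should contain at most one. A small sketch (the helper name is made up; obtaining the op types via the onnx package is shown in the docstring):

```python
def nms_unrolled(op_types):
    """Heuristic check: more than one NonMaxSuppression node suggests
    the per-class unrolled post-processing survived the export.

    op_types can be obtained with the onnx package, e.g.:
        import onnx
        op_types = [n.op_type for n in onnx.load("model.onnx").graph.node]
    """
    return sum(1 for t in op_types if t == "NonMaxSuppression") > 1
```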
