"""Processes a CSV file using AutoML's object detection format into tfrecords.
This script will accept a CSV file path or URL and write tfrecord files at the
provided output path. Example usage:
```
mkdir -p /tmp/salad_dataset
python tfrecords_from_dataset_descriptor.py \
--csv_path=gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv \
--output_path=/tmp/salad_dataset
```
"""

import pathlib

from absl import app
from absl import flags

import numpy as np
import pandas as pd
import tensorflow as tf

_FIELDNAMES = ["Subset", "Image", "Label", "XMin", "YMin", "_", "_", "XMax", "YMax", "_", "_"]

flags.DEFINE_string("csv_path", None, "Path to CSV with bounding box annotations.", required=True)
flags.DEFINE_string("output_path", None, "Output location of the tfrecord file(s).", required=True)
flags.DEFINE_integer("records_per_shard", 100, "Number of records per individual tfrecord shard.")

FLAGS = flags.FLAGS


def write_label_map(label_map, output_path):
  """Writes an object_detection string_int_label_map plaintext proto file."""
  label_map_items = [
      f"item {{\n id: {item_id}\n name: '{item_name}'\n}}"
      for item_name, item_id in label_map.items()
  ]
  with tf.io.gfile.GFile(output_path, "w") as f:
    f.write("\n".join(label_map_items))


def values_to_tf_feature(values):
  """Converts a list of values to the corresponding tf.train.Feature type."""
  # If the argument is a single value, wrap it into a list.
  value_list = values if isinstance(values, list) else [values]
  # If the type is string, we have to encode it as a byte array.
  if isinstance(value_list[0], str):
    value_list = [value.encode("utf8") for value in value_list]
  # Determine the feature class based on the value type.
  if isinstance(value_list[0], bytes):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value_list))
  if isinstance(value_list[0], float):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value_list))
  if isinstance(value_list[0], int):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))
  raise TypeError(f"Unknown feature type for values: {type(value_list[0])}.")
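
# For instance, values_to_tf_feature(640) returns an int64_list Feature,
# values_to_tf_feature([0.1, 0.7]) a float_list, and values_to_tf_feature("jpg")
# a bytes_list (after utf8 encoding). Box coordinates read via pandas arrive as
# numpy.float64, which still passes the float check since numpy.float64
# subclasses Python's float.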


def build_tf_example_proto(image_path, detection_boxes, detection_labels):
  """Converts a single object detection image sample to a tf.train.Example proto.

  Args:
    image_path: path to the image file.
    detection_boxes: list of [xmin, ymin, xmax, ymax] for each detection box.
    detection_labels: list of numerical labels corresponding to each box.

  Returns:
    The populated tf.train.Example proto.
  """
  # Read image from file.
  with tf.io.gfile.GFile(image_path, "rb") as f:
    image_data = f.read()
  image_tensor = tf.io.decode_image(image_data)
  image_extension = pathlib.Path(image_path).suffix[1:]
  height, width = image_tensor.shape[:2]
  # Validate that all boxes use coordinates that are within the [0, 1] range.
  bbox_values = np.array(detection_boxes).flatten()
  assert bbox_values.min() >= 0 and bbox_values.max() <= 1, "Detection boxes must be normalized."
  # Since nested structures are not supported by tf.train.Feature, the
  # object_detection API uses separate lists for each of the box coordinates
  # to represent detection boxes.
  bbox_xmin = [bbox[0] for bbox in detection_boxes]
  bbox_ymin = [bbox[1] for bbox in detection_boxes]
  bbox_xmax = [bbox[2] for bbox in detection_boxes]
  bbox_ymax = [bbox[3] for bbox in detection_boxes]
  # Convert the data into a tf.train.Example instance.
  return tf.train.Example(
      features=tf.train.Features(
          feature={
              "image/height": values_to_tf_feature(height),
              "image/width": values_to_tf_feature(width),
              "image/encoded": values_to_tf_feature(image_data),
              "image/format": values_to_tf_feature(image_extension),
              "image/object/bbox/xmin": values_to_tf_feature(bbox_xmin),
              "image/object/bbox/xmax": values_to_tf_feature(bbox_xmax),
              "image/object/bbox/ymin": values_to_tf_feature(bbox_ymin),
              "image/object/bbox/ymax": values_to_tf_feature(bbox_ymax),
              "image/object/class/label": values_to_tf_feature(detection_labels),
          }
      )
  )
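
# A minimal sketch (not part of this script) of how a consumer could decode
# these records, assuming the feature keys above; `paths` is a hypothetical
# list of tfrecord shard files:
#
#   feature_spec = {
#       "image/encoded": tf.io.FixedLenFeature([], tf.string),
#       "image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
#       "image/object/class/label": tf.io.VarLenFeature(tf.int64),
#   }
#   dataset = tf.data.TFRecordDataset(paths).map(
#       lambda pb: tf.io.parse_single_example(pb, feature_spec))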


def write_tfrecord_chunk(output_path, examples):
  """Writes a list of tf.train.Example instances to a tfrecord file."""
  with tf.io.TFRecordWriter(output_path) as writer:
    for example in examples:
      writer.write(example.SerializeToString())
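
# To spot-check a shard after writing, its records can be counted back
# directly; the path below is a hypothetical output of this script:
#
#   shard = "/tmp/salad_dataset/data-train.tfrecord-00000-of-00002"
#   num_records = sum(1 for _ in tf.data.TFRecordDataset(shard))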


def main(_):
  logger = tf.get_logger()

  # Validate that the output path points to a directory.
  base_path = pathlib.Path(FLAGS.output_path)
  assert base_path.is_dir(), "Parameter --output_path must be an existing directory."

  # Read the CSV file and discard unwanted columns.
  column_names = [f"_.{i}" if col == "_" else col for i, col in enumerate(_FIELDNAMES)]
  with tf.io.gfile.GFile(FLAGS.csv_path) as f:
    usecols = [col for col in column_names if not col.startswith("_")]
    df = pd.read_csv(f, names=column_names, usecols=usecols, header=None)
  logger.info(f"Read {len(df)} records from CSV file.")

  # Write out the label map into a pbtxt file for later use.
  labels_path = base_path / "labels.pbtxt"
  logger.info(f"Writing {labels_path} file.")
  label_map = {label: (i + 1) for i, label in enumerate(df.Label.unique())}
  write_label_map(label_map, str(labels_path))

  # Iterate over dataset rows by subset (train, test and validation).
  df.Subset = df.Subset.str.lower()
  for subset, subset_group in df.groupby("Subset"):
    logger.info(f"Processing {subset} subset.")

    # Group all the detection boxes by image and compute the total number of
    # shards, rounding up so that an exact multiple of records_per_shard does
    # not produce an extra, empty shard.
    image_groups = subset_group.groupby("Image")
    num_shards = (len(image_groups) + FLAGS.records_per_shard - 1) // FLAGS.records_per_shard
    logger.info(f"Writing {subset} records into {num_shards} shards.")

    # Iterate over the group of detections associated with each image.
    tfrecord_index = 0
    tfrecord_buffer = []
    extract_bbox_from_record = lambda x: [x["XMin"], x["YMin"], x["XMax"], x["YMax"]]
    for image_path, image_group in image_groups:
      detection_labels = [label_map[x] for x in image_group.Label.values]
      detection_boxes = image_group.apply(extract_bbox_from_record, axis=1).values.tolist()
      record = build_tf_example_proto(image_path, detection_boxes, detection_labels)
      tfrecord_buffer.append(record)
      logger.info(f"Buffered {len(tfrecord_buffer)} records.")

      # Flush the buffer once we reach the desired number of examples.
      if len(tfrecord_buffer) == FLAGS.records_per_shard:
        suffix = f"{tfrecord_index:05d}-of-{num_shards:05d}"
        output_path = f"{FLAGS.output_path}/data-{subset}.tfrecord-{suffix}"
        write_tfrecord_chunk(output_path, tfrecord_buffer)
        tfrecord_buffer.clear()
        tfrecord_index += 1

    # Write out any remaining items in the buffer.
    if len(tfrecord_buffer) > 0:
      suffix = f"{tfrecord_index:05d}-of-{num_shards:05d}"
      output_path = f"{FLAGS.output_path}/data-{subset}.tfrecord-{suffix}"
      write_tfrecord_chunk(output_path, tfrecord_buffer)
      assert tfrecord_index + 1 == num_shards

    logger.info(f"Finished writing {subset} records.")


if __name__ == "__main__":
  app.run(main)