Skip to content

Instantly share code, notes, and snippets.

@Hermanoid
Created January 4, 2024 17:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hermanoid/991a0f96222c7b27e02bafb70d3d9eb4 to your computer and use it in GitHub Desktop.
Save Hermanoid/991a0f96222c7b27e02bafb70d3d9eb4 to your computer and use it in GitHub Desktop.
TFRecord Image Infuser (for free-tier CVAT exports with no images)
import tensorflow as tf
import os
# tkinter is optional: if it cannot be imported, the script falls back to
# plain terminal prompts instead of file dialogs.
try:
    from tkinter import Tk
    from tkinter.filedialog import askopenfilename, askdirectory
except ImportError:
    headless = True
else:
    # The flag must also exist on the success path; otherwise every later
    # `if headless:` check raises NameError when tkinter is available.
    headless = False
# Startup banner: greet the user and report the TensorFlow version and
# whether eager execution is enabled.
tf_version = tf.__version__
eager_enabled = tf.executing_eagerly()
print("Welcome to the TFRecord Image Infuser!")
print("TensorFlow version: {}".format(tf_version))
print("Eager execution: {}".format(eager_enabled))
def get_directory(prompt, initialdir):
    """Ask the user for a directory, via the terminal or a Tk dialog.

    Args:
        prompt: Text shown to the user (terminal prompt or dialog title).
        initialdir: Directory offered as the default.

    Returns:
        The chosen directory path, or ``initialdir`` when the answer is
        empty (blank terminal input, or a dialog that yields no selection,
        e.g. because it was cancelled).
    """
    if headless:
        # get the folder from the terminal
        response = input(f"{prompt} (default= {initialdir}): ")
        # Single quotes are removed from the typed/pasted path.
        chosen = response.replace("'", "")
    else:
        chosen = askdirectory(title=prompt, initialdir=initialdir)
    # Both modes share the same fallback, so an empty result never silently
    # turns into "current working directory" further down the line.
    return chosen or initialdir
# Prompt for input .tfrecord file
if headless:
    # Single quotes are removed from the typed/pasted path.
    input_file = input("Select input .tfrecord file: ").replace("'", "")
else:
    # Create the Tk root and hide it so only the file dialog is shown.
    Tk().withdraw()
    input_file = askopenfilename(title="Select input .tfrecord file")
if not input_file:
    # Blank answer or dismissed dialog: there is nothing to process, so stop
    # with a clear message instead of failing later on an empty path.
    raise SystemExit("No input .tfrecord file selected.")
# Prompt for images folder.
# The dialog/prompt starts in an "images" directory next to the input file
# when one exists, otherwise in the input file's own directory.
input_dir = os.path.dirname(input_file)
default_images_folder = os.path.join(input_dir, "images")
initial_folder = (
    default_images_folder if os.path.isdir(default_images_folder) else input_dir
)
images_folder = get_directory("Select images folder", initialdir=initial_folder)

# Records are read from the selected file.
dataset = tf.data.TFRecordDataset(input_file)
def _bytes_feature(value):
    """Wrap *value* in a ``tf.train.Feature`` holding a one-element bytes_list.

    Args:
        value: The payload, or an eager tensor whose ``.numpy()`` gives it.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains the payload.
    """
    eager_tensor_type = type(tf.constant(0))
    # BytesList won't unpack a string from an EagerTensor, so convert first.
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))
def read_files(record):
    """Return the serialized example for *record* with its image data added.

    The record is parsed as a ``tf.train.Example``. The image named by its
    ``image/filename`` feature is loaded through ``read_image_file`` and stored
    under ``image/encoded``; ``image/format`` is set to ``image/<suffix>``,
    where ``<suffix>`` is the text after the last ``.`` in the filename.

    Args:
        record: Object whose ``.numpy()`` returns one serialized example.

    Returns:
        The updated example, serialized to bytes.
    """
    parsed = tf.train.Example.FromString(record.numpy())
    features = dict(parsed.features.feature)

    # Get the filename
    filename = features['image/filename'].bytes_list.value[0].decode('utf-8')

    # Add the image data to the example
    features['image/encoded'] = read_image_file(filename)

    # Derive the format label from the filename's suffix
    suffix = filename.rsplit(".", 1)[-1]
    image_format = "image/" + suffix
    features["image/format"] = _bytes_feature(image_format.encode("utf-8"))

    # Serialize the example and return it
    rebuilt = tf.train.Example(features=tf.train.Features(feature=features))
    return rebuilt.SerializeToString()
def read_image_file(filename):
    """Load an image from the images folder as a bytes feature.

    Args:
        filename: Name of the file, joined onto the module-level
            ``images_folder``.

    Returns:
        The result of ``_bytes_feature`` applied to the file's raw bytes.
    """
    image_path = os.path.join(images_folder, filename)
    # Read the image file
    with open(image_path, 'rb') as image_file:
        raw_bytes = image_file.read()
    return _bytes_feature(raw_bytes)
# Write the result to a new TFRecord file, named after the input file with
# its extension replaced by "_with_images.tfrecord".
input_stem, _ = os.path.splitext(input_file)
output_file = input_stem + "_with_images.tfrecord"
with tf.io.TFRecordWriter(output_file) as writer:
    for index, serialized in dataset.enumerate():
        # "\r" with end="" keeps the progress counter on a single line.
        print("\rProcessing record {}".format(index + 1), end="")
        writer.write(read_files(serialized))
print("\nDone!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment