Created
January 4, 2024 17:19
-
-
Save Hermanoid/991a0f96222c7b27e02bafb70d3d9eb4 to your computer and use it in GitHub Desktop.
TFRecord Image Infuser (for free-tier CVAT exports with no images)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import os | |
# Prefer native Tk dialogs; fall back to terminal prompts when tkinter cannot
# be imported.
try:
    from tkinter import Tk
    from tkinter.filedialog import askopenfilename, askdirectory
except ImportError:
    headless = True
else:
    # Without this branch `headless` is never assigned when tkinter imports
    # successfully, and every later `if headless:` check raises NameError.
    headless = False
# Startup banner: greeting followed by the TensorFlow runtime details.
print("Welcome to the TFRecord Image Infuser!")
for template, value in (
    ("TensorFlow version: {}", tf.__version__),
    ("Eager execution: {}", tf.executing_eagerly()),
):
    print(template.format(value))
def get_directory(prompt, initialdir):
    """Ask the user for a directory and return the chosen path.

    With a GUI available this opens a Tk directory dialog starting at
    `initialdir`. In headless mode it prompts on the terminal instead; an
    empty response selects `initialdir`, and single quotes are removed from
    any typed response.
    """
    if not headless:
        return askdirectory(title=prompt, initialdir=initialdir)

    # Terminal fallback: read the path from standard input.
    answer = input(f"{prompt} (default= {initialdir}): ")
    if not answer:
        return initialdir
    return answer.replace("'", "")
# Prompt for input .tfrecord file
if headless:
    # Single quotes are removed from whatever was typed at the prompt.
    input_file = input("Select input .tfrecord file: ").replace("'", "")
else:
    Tk().withdraw()  # hide the root window; only the file dialog is wanted
    input_file = askopenfilename(title="Select input .tfrecord file")

# A cancelled dialog or a bare Enter at the prompt yields an empty value.
# Stop here with a clear message instead of continuing with an empty path.
if not input_file:
    raise SystemExit("No input .tfrecord file selected.")

# Prompt for images folder, starting in an "images" folder next to the input
# file when one exists, otherwise in the input file's own folder.
input_dir = os.path.dirname(input_file)
default_images_folder = os.path.join(input_dir, "images")
if os.path.isdir(default_images_folder):
    initial_folder = default_images_folder
else:
    initial_folder = input_dir
images_folder = get_directory("Select images folder", initialdir=initial_folder)

dataset = tf.data.TFRecordDataset(input_file)
def _bytes_feature(value):
    """Return a tf.train.Feature holding `value` as a one-element bytes_list."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor, so convert it
        # with .numpy() first.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def read_files(record):
    """Return `record` re-serialized with image bytes and image format added.

    `record` must expose `.numpy()` yielding a serialized tf.train.Example.
    The image is located through the record's 'image/filename' feature and
    loaded with read_image_file.
    """
    parsed = tf.train.Example.FromString(record.numpy())
    features = dict(parsed.features.feature)

    # The first entry of 'image/filename' names the image for this record.
    filename = features['image/filename'].bytes_list.value[0].decode('utf-8')

    # Attach the image file's contents as the 'image/encoded' feature.
    features['image/encoded'] = read_image_file(filename)

    # The format is "image/" plus whatever follows the last "." in the name.
    extension = filename.rsplit(".", 1)[-1]
    image_format = "image/" + extension
    features["image/format"] = _bytes_feature(image_format.encode("utf-8"))

    # Serialize the rebuilt example and hand it back.
    rebuilt = tf.train.Example(features=tf.train.Features(feature=features))
    return rebuilt.SerializeToString()
def read_image_file(filename):
    """Read `filename` from the module-level images_folder as a bytes Feature."""
    image_path = os.path.join(images_folder, filename)
    with open(image_path, 'rb') as image_file:
        raw_bytes = image_file.read()
    return _bytes_feature(raw_bytes)
# Write the result to a new TFRecord file placed next to the input, named
# after it with a "_with_images.tfrecord" suffix.
output_file = os.path.splitext(input_file)[0] + "_with_images.tfrecord"
with tf.io.TFRecordWriter(output_file) as writer:
    for index, serialized in dataset.enumerate():
        # "\r" returns to the start of the line so the counter updates in place.
        progress = "\rProcessing record {}".format(index + 1)
        print(progress, end="")
        writer.write(read_files(serialized))
print("\nDone!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment