Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import os
import tensorflow as tf
from PIL import Image
from functools import partial
def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    The record is expected to hold a scalar bytes feature ``"image"`` (an
    encoded image) and a scalar int64 feature ``"label"``.

    Args:
        example: A scalar string tensor containing one serialized Example.

    Returns:
        ``(image, label)``: the image decoded with ``tf.image.decode_jpeg``
        into 3 channels, and the label cast to ``tf.int32``.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return (
        tf.image.decode_jpeg(parsed['image'], channels=3),
        tf.cast(parsed['label'], tf.int32),
    )
def read_dataset(dataset_path, split='validation', batch_size=128,
                 sample_path='./images/img.jpeg'):
    """Build a repeating, batched dataset from ``<dataset_path>/<split>/*.tfrecord``.

    As a smoke test, the first image of the first batch is written to
    ``sample_path`` (skipped when ``sample_path`` is falsy).

    Args:
        dataset_path: Root directory (str or path-like) holding one
            sub-directory of TFRecord shards per split.
        split: Sub-directory to read; defaults to ``'validation'``.
        batch_size: Number of examples per batch.
        sample_path: Local file path for the sample JPEG, or ``None``/``''``
            to skip writing it.

    Returns:
        The ``tf.data.Dataset`` of ``(images, labels)`` batches. It repeats
        indefinitely.

    Raises:
        FileNotFoundError: If no TFRecord files match the glob pattern.
    """
    # Join with '/' rather than os.path.join so remote (gs://) roots keep
    # working on every OS; strip a trailing '/' to avoid 'root//split'.
    root = os.fspath(dataset_path).rstrip('/')
    pattern = f'{root}/{split}/*.tfrecord'
    filenames = tf.io.gfile.glob(pattern)
    if not filenames:
        # Without this check the pipeline is empty and the sample loop below
        # silently does nothing.
        raise FileNotFoundError(f'No TFRecord files match {pattern!r}')

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)

    if sample_path:
        # PIL does not create missing parent directories.
        sample_dir = os.path.dirname(sample_path)
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)
        for images, _ in dataset.take(1):
            Image.fromarray(images[0].numpy()).save(sample_path)

    return dataset
def process_image(image_tensor, quality=100):
    """Encode an image tensor as a JPEG byte string.

    Pixel values are interpreted on the 0-255 scale: they are rounded,
    clipped to [0, 255] and cast to uint8 before encoding. In this file the
    input comes from ``image_dataset_from_directory``, whose float32 output
    is on that scale.

    Args:
        image_tensor: A rank-3 image tensor of any real dtype. Presumably
            shaped (height, width, channels); TODO confirm for other callers.
        quality: JPEG quality passed to ``tf.io.encode_jpeg`` (default 100).

    Returns:
        A scalar string tensor holding the encoded JPEG.
    """
    # tf.image.convert_image_dtype assumes floats in [0, 1] and multiplies by
    # 255 with a non-saturating cast, which overflows for 0-255 inputs.
    pixels = tf.cast(image_tensor, tf.float32)
    pixels = tf.clip_by_value(tf.round(pixels), 0.0, 255.0)
    image = tf.cast(pixels, tf.uint8)
    return tf.io.encode_jpeg(image, quality=quality)
def make_example(encoded_image, label):
    """Serialize one encoded image and its integer label as a ``tf.train.Example``.

    The feature keys (``'image'`` as bytes, ``'label'`` as int64) match the
    spec parsed by ``read_tfrecord``.

    Args:
        encoded_image: The encoded image bytes.
        label: An integer label. Python ints, numpy integer scalars and
            scalar tensors are all accepted.

    Returns:
        The serialized Example as ``bytes``.
    """
    features = tf.train.Features(feature={
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_image])),
        # as_numpy_iterator yields numpy integer scalars; coerce to a plain
        # int because some protobuf backends reject numpy types in Int64List.
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)])),
    })
    return tf.train.Example(features=features).SerializeToString()
def create_tfrecord_shards(ds, num_shards, path):
    """Write ``ds`` into ``num_shards`` TFRecord files in a single pass.

    Element ``i`` goes to shard ``i % num_shards``, the same assignment
    ``tf.data.Dataset.shard`` uses. Every shard file is created, even if it
    receives no elements.

    Iterating the dataset only once matters for two reasons. It avoids
    re-running the whole input pipeline per shard. It also guarantees each
    element lands in exactly one shard when the source reshuffles on every
    iteration.

    Args:
        ds: A ``tf.data.Dataset`` yielding ``(encoded_image, label)`` pairs.
        num_shards: Number of output files; must be >= 1.
        path: A ``str.format`` template with one placeholder for the shard
            index, e.g. ``'shard_{:02d}.tfrecord'``.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    import contextlib

    if num_shards < 1:
        raise ValueError(f'num_shards must be >= 1, got {num_shards}')

    # ExitStack closes every writer, including on error mid-write.
    with contextlib.ExitStack() as stack:
        writers = [
            stack.enter_context(tf.io.TFRecordWriter(path=path.format(shard)))
            for shard in range(num_shards)
        ]
        for index, (encoded_image, label) in enumerate(ds.as_numpy_iterator()):
            writers[index % num_shards].write(make_example(encoded_image, label))
def images2tfrecords(image_dir='<image_dataset_dir_path>',
                     tfrecords_dir='<tfrecords_dir_path>',
                     num_training_shards=32,
                     num_validation_shards=8):
    """Convert an image directory into sharded training/validation TFRecords.

    Images are loaded with ``image_dataset_from_directory``: 80/20
    train/validation split, seed 123, resized to 600x600. Each image is
    re-encoded as JPEG via ``process_image``. The results are written to
    ``<tfrecords_dir>/training/shard_NN.tfrecord`` and
    ``<tfrecords_dir>/validation/shard_NN.tfrecord``.

    Args:
        image_dir: Root of the class-per-subdirectory image dataset.
        tfrecords_dir: Output root for the TFRecord shards.
        num_training_shards: Number of training shard files.
        num_validation_shards: Number of validation shard files.
    """
    load_split = partial(
        tf.keras.preprocessing.image_dataset_from_directory,
        image_dir,
        validation_split=0.2,
        shuffle=True,
        seed=123,  # must be identical for both subsets so they don't overlap
        image_size=(600, 600),
        batch_size=1,
    )

    for subset, num_shards in (('training', num_training_shards),
                               ('validation', num_validation_shards)):
        ds = load_split(subset=subset)
        # Elements are (image, label) tuples, which tf.data unpacks into two
        # positional arguments. Encode only the image and carry the label
        # through unchanged.
        ds_encoded = ds.unbatch().map(
            lambda image, label: (process_image(image), label))

        split_dir = os.path.join(tfrecords_dir, subset)
        os.makedirs(split_dir, exist_ok=True)
        shard_path = os.path.join(split_dir, 'shard_{:02d}.tfrecord')
        create_tfrecord_shards(ds_encoded, num_shards, shard_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment