Skip to content

Instantly share code, notes, and snippets.

@deeDude
Last active August 1, 2021 12:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save deeDude/c40fa1f14e4fa4b7f2ef149e5a344023 to your computer and use it in GitHub Desktop.
Save deeDude/c40fa1f14e4fa4b7f2ef149e5a344023 to your computer and use it in GitHub Desktop.
import os
from contextlib import ExitStack
from functools import partial

import tensorflow as tf
from PIL import Image
def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Args:
        example: A scalar string tensor holding a serialized example with a
            bytes feature ``"image"`` and an int64 feature ``"label"``.

    Returns:
        A tuple ``(image, label)``: the JPEG-decoded image with 3 channels
        and the label cast to ``tf.int32``.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return (
        tf.image.decode_jpeg(parsed['image'], channels=3),
        tf.cast(parsed['label'], tf.int32),
    )
def read_dataset(dataset_path):
    """Load the validation TFRecords and dump one decoded image as a sanity check.

    Reads every ``*.tfrecord`` file under ``<dataset_path>/validation``,
    parses each record with ``read_tfrecord``, repeats the stream
    indefinitely and groups it into batches of 128.

    As a smoke test, the first image of the first batch is written to
    ``./images/img.jpeg``; the ``./images`` directory is created when
    missing so the save cannot fail on a fresh checkout.

    Args:
        dataset_path: Directory (or URL prefix understood by
            ``tf.io.gfile``) that contains the ``validation`` sub-directory.

    Returns:
        The repeated, batched ``tf.data.Dataset`` of ``(image, label)``.
    """
    filenames = tf.io.gfile.glob(dataset_path + '/validation/*.tfrecord')
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord)
    dataset = dataset.repeat()
    dataset = dataset.batch(128)
    # read and save the first image of the first batch, for test purposes
    os.makedirs('./images', exist_ok=True)
    for image, _label in dataset.take(1):
        Image.fromarray(image[0].numpy()).save('./images/img.jpeg')
    return dataset
def process_image(image_tensor, label=None):
    """JPEG-encode an image tensor, optionally passing its label through.

    ``tf.data.Dataset.map`` unpacks ``(image, label)`` elements into two
    positional arguments, so the function accepts the label and returns it
    unchanged next to the encoded image. Called with the image alone it
    returns just the encoded bytes.

    Args:
        image_tensor: Image tensor of shape ``(height, width, channels)``.
            Floating-point input is treated as pixel values already in the
            ``[0, 255]`` range (what ``image_dataset_from_directory``
            yields) and is rounded and saturated to ``uint8``. Integer input
            is converted with ``tf.image.convert_image_dtype``.
        label: Optional label associated with the image.

    Returns:
        The JPEG-encoded image as a scalar string tensor when ``label`` is
        ``None``; otherwise the tuple ``(encoded_image, label)``.
    """
    if image_tensor.dtype.is_floating:
        # convert_image_dtype assumes floats lie in [0, 1] and rescales them,
        # which would corrupt pixels that are already in [0, 255].
        image = tf.saturate_cast(tf.round(image_tensor), tf.uint8)
    else:
        image = tf.image.convert_image_dtype(image_tensor, dtype=tf.uint8)
    encoded = tf.io.encode_jpeg(image, quality=100)
    if label is None:
        return encoded
    return encoded, label
def make_example(encoded_image, label):
    """Serialize an encoded image and its label as a ``tf.train.Example``.

    Args:
        encoded_image: Encoded image bytes, stored under the ``'image'`` key.
        label: Integer class label, stored under the ``'label'`` key.

    Returns:
        The serialized example as a byte string.
    """
    feature_map = {
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_image]),
        ),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label]),
        ),
    }
    record = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return record.SerializeToString()
def create_tfrecord_shards(ds, num_shards, path):
    """Write a dataset of ``(encoded_image, label)`` pairs to TFRecord shards.

    The dataset is iterated exactly once and its elements are dealt
    round-robin to the shard files (element ``i`` goes to shard
    ``i % num_shards``). A single pass keeps the cost independent of the
    number of shards and guarantees every element lands in exactly one
    shard even if ``ds`` yields a different order on each iteration (for
    example a ``shuffle`` that reshuffles per iteration).

    Args:
        ds: ``tf.data.Dataset`` yielding ``(encoded_image, label)`` pairs.
        num_shards: Number of shard files to create. Values ``<= 0`` write
            nothing.
        path: Format string for the shard file name; it is formatted with
            the shard index, e.g. ``'shard_{:02d}.tfrecord'``.
    """
    if num_shards <= 0:
        return
    # ExitStack closes every writer, including when a write raises midway.
    with ExitStack() as stack:
        writers = [
            stack.enter_context(tf.io.TFRecordWriter(path=path.format(shard)))
            for shard in range(num_shards)
        ]
        for index, (encoded_image, label) in enumerate(ds.as_numpy_iterator()):
            example = make_example(encoded_image, label)
            writers[index % num_shards].write(example)
def images2tfrecords(
    image_dir='<image_dataset_dir_path>',
    tfrecords_dir='<tfrecords_dir_path>',
    train_shards=32,
    valid_shards=8,
):
    """Convert an image directory into sharded training/validation TFRecords.

    The images under ``image_dir`` are loaded with
    ``tf.keras.preprocessing.image_dataset_from_directory`` using an 80/20
    split (seed 123) and an image size of 600x600. Each image is
    JPEG-encoded with ``process_image`` and written together with its label
    to ``<tfrecords_dir>/training`` and ``<tfrecords_dir>/validation``.

    Args:
        image_dir: Root directory of the image dataset.
        tfrecords_dir: Directory that receives the ``training`` and
            ``validation`` shard sub-directories; created when missing.
        train_shards: Number of shard files for the training split.
        valid_shards: Number of shard files for the validation split.
    """
    load_split = partial(
        tf.keras.preprocessing.image_dataset_from_directory,
        image_dir,
        validation_split=0.2,
        shuffle=True,
        seed=123,
        image_size=(600, 600),
        batch_size=1,
    )

    def encode(image, label):
        # Dataset.map unpacks (image, label) elements into two arguments;
        # keep the label so the shard writer receives complete pairs.
        return process_image(image), label

    splits = (
        ('training', train_shards),
        ('validation', valid_shards),
    )
    for subset, num_shards in splits:
        encoded = load_split(subset=subset).unbatch().map(encode)
        out_dir = os.path.join(tfrecords_dir, subset)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(out_dir, exist_ok=True)
        shard_path = os.path.join(out_dir, 'shard_{:02d}.tfrecord')
        create_tfrecord_shards(encoded, num_shards, shard_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment