This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
from PIL import Image | |
from functools import partial | |
def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    The record is expected to carry two scalar features: ``"image"`` (a
    JPEG-encoded byte string) and ``"label"`` (an int64).

    Args:
        example: A scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the JPEG decoded with
        three channels and ``label`` is the stored label cast to ``tf.int32``.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return (
        tf.image.decode_jpeg(parsed['image'], channels=3),
        tf.cast(parsed['label'], tf.int32),
    )
def read_dataset(dataset_path):
    """Load the validation TFRecords and dump one sample image for inspection.

    All ``*.tfrecord`` files under ``<dataset_path>/validation`` are read,
    parsed with ``read_tfrecord``, repeated indefinitely and grouped into
    batches of 128. The first image of the first batch is written to
    ``./images/img.jpeg`` as a quick sanity check.

    Args:
        dataset_path: Root directory that contains the ``validation``
            sub-directory with the TFRecord shards.

    Returns:
        The repeated, batched ``tf.data.Dataset`` of ``(image, label)`` pairs.
    """
    filenames = tf.io.gfile.glob(dataset_path + '/validation/*.tfrecord')
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord)
    dataset = dataset.repeat()
    dataset = dataset.batch(128)

    # The output directory is not guaranteed to exist; without this the
    # save below fails with FileNotFoundError on a fresh checkout.
    os.makedirs('./images', exist_ok=True)

    # Read and save the first image of the first batch, for test purposes.
    for image, _ in dataset.take(1):
        Image.fromarray(image[0].numpy()).save('./images/img.jpeg')

    return dataset
def process_image(image_tensor, label=None):
    """JPEG-encode an image tensor, optionally passing its label through.

    ``tf.data.Dataset.map`` unpacks tuple elements into positional arguments,
    so a dataset of ``(image, label)`` pairs calls this function with two
    arguments. Accepting the optional ``label`` lets the function be mapped
    over such datasets while still supporting single-argument calls.

    Args:
        image_tensor: Image to encode. It is converted to ``tf.uint8`` with
            ``tf.image.convert_image_dtype``, which treats floating-point
            input as lying in ``[0, 1]``; float pixels in ``[0, 255]`` must be
            rescaled by the caller first.
        label: Optional label associated with the image. It is returned
            unchanged.

    Returns:
        The JPEG-encoded image (a scalar string tensor) when ``label`` is
        ``None``; otherwise the tuple ``(encoded_image, label)``.
    """
    image = tf.image.convert_image_dtype(image_tensor, dtype=tf.uint8)
    image = tf.io.encode_jpeg(image, quality=100)
    if label is None:
        return image
    return image, label
def make_example(encoded_image, label):
    """Serialize one encoded image and its label as a ``tf.train.Example``.

    Args:
        encoded_image: The encoded image bytes. An eager tensor is also
            accepted and is unwrapped with ``.numpy()``.
        label: Integer class label. NumPy integer scalars and scalar eager
            tensors are coerced to a plain Python ``int``.

    Returns:
        The serialized example as ``bytes``, with features ``'image'``
        (bytes list) and ``'label'`` (int64 list).
    """
    # Eager string tensors must be unwrapped before they can go in a BytesList.
    if hasattr(encoded_image, 'numpy'):
        encoded_image = encoded_image.numpy()

    # Labels coming from ``as_numpy_iterator`` are NumPy scalars, which some
    # protobuf versions refuse in an Int64List; a Python int always works.
    label = int(label)

    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[encoded_image])
    )
    label_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[label])
    )
    features = tf.train.Features(feature={
        'image': image_feature,
        'label': label_feature,
    })
    example = tf.train.Example(features=features)
    return example.SerializeToString()
def create_tfrecord_shards(ds, num_shards, path):
    """Write a dataset of ``(encoded_image, label)`` pairs into TFRecord shards.

    The dataset is traversed exactly once and element ``i`` goes to shard
    ``i % num_shards``. For a deterministic dataset this is the same
    assignment as ``ds.shard(num_shards, k)``, without re-running the input
    pipeline once per shard. It also keeps shards disjoint and complete if the
    dataset yields a different order on every iteration (for example a
    shuffle that reshuffles each iteration); separate passes per shard could
    then duplicate or drop examples.

    Args:
        ds: ``tf.data.Dataset`` whose elements are ``(encoded_image, label)``.
        num_shards: Number of shard files to produce. Nothing is written when
            it is not positive.
        path: Format string for the shard file name; it is formatted with the
            shard index, e.g. ``'shard_{:02d}.tfrecord'``.
    """
    import contextlib

    if num_shards <= 0:
        return

    with contextlib.ExitStack() as stack:
        # Open every shard up front so each file exists even if it stays empty.
        writers = [
            stack.enter_context(tf.io.TFRecordWriter(path=path.format(shard)))
            for shard in range(num_shards)
        ]
        for index, (encoded_image, label) in enumerate(ds.as_numpy_iterator()):
            writers[index % num_shards].write(make_example(encoded_image, label))
def images2tfrecords(image_dir='<image_dataset_dir_path>',
                     tfrecords_dir='<tfrecords_dir_path>'):
    """Convert an image directory into sharded training/validation TFRecords.

    Images are loaded with ``image_dataset_from_directory`` using an 80/20
    training/validation split (fixed seed), resized to 600x600, JPEG-encoded
    and written as 32 training shards and 8 validation shards under
    ``<tfrecords_dir>/training`` and ``<tfrecords_dir>/validation``.

    Args:
        image_dir: Directory of images, one sub-directory per class.
        tfrecords_dir: Root output directory for the TFRecord shards.
    """
    load_split = partial(
        tf.keras.preprocessing.image_dataset_from_directory,
        image_dir,
        validation_split=0.2,
        shuffle=True,
        seed=123,
        image_size=(600, 600),
        batch_size=1,
    )

    def encode_pair(image, label):
        # Dataset elements are (image, label) pairs, so the mapped function
        # must take both and keep the label for the writer. The loader yields
        # float pixels in [0, 255], while process_image converts with
        # convert_image_dtype, which expects floats in [0, 1]; rescale first.
        return process_image(image / 255.0), label

    for subset, num_shards in (('training', 32), ('validation', 8)):
        ds_encoded = load_split(subset=subset).unbatch().map(encode_pair)

        out_dir = os.path.join(tfrecords_dir, subset)
        os.makedirs(out_dir, exist_ok=True)

        shard_path = os.path.join(out_dir, 'shard_{:02d}.tfrecord')
        create_tfrecord_shards(ds_encoded, num_shards, shard_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment