@AdityaKane2001
Created June 21, 2021 05:27
"""Script to make and save TFRecords from ImageNet files"""
import tensorflow as tf
import os
import random
import json
import math
import time
from .image_utils import *
from typing import Tuple, List
_last_written_shards = 0
_start_time = time.time()
_logging_gap = 0
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
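

# Illustration (not part of the original script): _bytes_feature accepts either
# raw Python bytes or a tf.string EagerTensor (e.g. the output of
# tf.io.read_file / tf.io.encode_jpeg), since it calls .numpy() on eager
# tensors before wrapping the value in a BytesList. For example:
#
#     _bytes_feature(b"abc")               # from plain bytes
#     _bytes_feature(tf.constant(b"abc"))  # from a tf.string EagerTensor
#     _int64_feature(7)                    # from a Python int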
# Call chain: make_tfrecs -> _make_single_tfrecord (one call per shard)
#             -> _make_image / _make_example


def _get_synset_labels(filepath: str) -> dict:
    """
    Reads the synset JSON file and returns its contents as a dict.

    Args:
        filepath: path to the synset JSON file

    Returns:
        Dict with the following structure:
        {str id: (int synset_ID, str label_name)}
    """
    with open(filepath, "r") as f:
        raw_labels_dict = json.load(f)
    labels_dict = dict()
    for i in raw_labels_dict:
        labels_dict[raw_labels_dict[i]["id"]] = (
            int(i),
            raw_labels_dict[i]["label"],
        )
    return labels_dict
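

# For reference, the loop in _get_synset_labels assumes the synset JSON has
# roughly this shape (illustrative values only; the real
# imagenet_synset_to_human.json may differ in detail):
#
#     {
#         "0": {"id": "01440764-n", "label": "tench, Tinca tinca"},
#         "1": {"id": "01443537-n", "label": "goldfish, Carassius auratus"},
#         ...
#     }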
def _get_default_synset_path() -> str:
    """Returns the default path to imagenet_synset_to_human.json, resolved
    relative to this file's package root."""
    self_path = __file__
    path_segments = self_path.split('/')
    regnety_path = '/'.join(path_segments[:-2])
    return os.path.join(regnety_path, 'config', 'imagenet_synset_to_human.json')
def _get_files(
    data_dir: str,
    synset_filepath: str,
    shuffle: bool = True
) -> Tuple[List[str], List[int], List[str]]:
    """
    Returns lists of all image files, their integer labels and their synsets.

    Args:
        data_dir: directory containing an ImageNet-style directory structure
            (synset IDs as directory names, images inside)
        synset_filepath: path to the synset JSON file
        shuffle: True if the data needs to be shuffled

    Returns:
        all_images: paths to all image files
        all_labels_int: integer labels corresponding to images in all_images
        all_synsets: synset strings corresponding to images in all_images
    """
    all_images = tf.io.gfile.glob(os.path.join(data_dir, "*", "*.JPEG"))
    all_synsets = [os.path.basename(os.path.dirname(f)) for f in all_images]
    print(len(all_images))

    all_indexes = list(range(len(all_images)))
    if shuffle:
        random.shuffle(all_indexes)
    all_images = [all_images[i] for i in all_indexes]
    # Directory names are synset IDs like "n01440764"; the synset JSON keys
    # use the "01440764-n" form, so rewrite accordingly.
    all_synsets = [all_synsets[i][1:] + "-n" for i in all_indexes]

    labels_dict = _get_synset_labels(synset_filepath)
    all_labels_int = [labels_dict[i][0] for i in all_synsets]
    return all_images, all_labels_int, all_synsets
def _make_image(filepath: str, shard: int) -> Tuple[str, int, int]:
    """
    Reads an image and returns its raw JPEG RGB byte string along with its
    height and width. Also logs shard-writing progress every _logging_gap
    seconds.

    Args:
        filepath: path to a .JPEG image
        shard: index of the shard currently being written (used for logging)

    Returns:
        image_str: byte string of the image in JPEG RGB format
        height: image height in pixels
        width: image width in pixels
    """
    global _start_time, _logging_gap, _last_written_shards

    if time.time() - _start_time >= _logging_gap:
        print('%d shards were completed in the last %d seconds.' %
              (shard - _last_written_shards, _logging_gap))
        _last_written_shards = shard
        _start_time = time.time()

    image_str = tf.io.read_file(filepath)
    if is_png(filepath):
        image_str = png_to_jpeg(image_str)
    if is_cmyk(filepath):
        image_str = cmyk_to_rgb(image_str)

    image_tensor = tf.io.decode_jpeg(image_str)
    height, width = image_tensor.shape[0], image_tensor.shape[1]
    if not is_rgb(image_tensor):
        image_tensor = tf.image.grayscale_to_rgb(image_tensor)
        image_str = tf.io.encode_jpeg(image_tensor)
    assert len(image_tensor.shape) == 3
    return image_str, height, width
def _make_example(
    image_str: bytes,
    height: int,
    width: int,
    filepath: str,
    label: int,
    synset: str
) -> tf.train.Example:
    """
    Makes a single tf.train.Example from the given arguments.

    Args:
        image_str: byte string of the image in JPEG RGB format
        height: height of the image in pixels
        width: width of the image in pixels
        filepath: path to the image
        label: integer label
        synset: synset string corresponding to the image

    Returns:
        A tf.train.Example holding the aforementioned attributes.
    """
    # Note: the original try/except around bytes(...) always fell through to
    # the except branch in Python 3; encoding the strings directly is
    # equivalent and simpler.
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                "image": _bytes_feature(image_str),
                "height": _int64_feature(height),
                "width": _int64_feature(width),
                "filename": _bytes_feature(
                    os.path.basename(filepath).encode("utf8")
                ),
                "label": _int64_feature(label),
                "synset": _bytes_feature(synset.encode("utf8")),
            }
        )
    )
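

# A minimal sketch (not part of the original script) of how records written by
# _make_example could be read back. The feature spec mirrors the keys written
# above; the names _FEATURE_SPEC and _parse_example are hypothetical.
_FEATURE_SPEC = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "height": tf.io.FixedLenFeature([], tf.int64),
    "width": tf.io.FixedLenFeature([], tf.int64),
    "filename": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
    "synset": tf.io.FixedLenFeature([], tf.string),
}


def _parse_example(serialized_example):
    """Decodes one serialized example back into an (image, label) pair."""
    parsed = tf.io.parse_single_example(serialized_example, _FEATURE_SPEC)
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    return image, parsed["label"]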
def _make_single_tfrecord(
    chunk_files: List[str],
    chunk_synsets: List[str],
    chunk_labels: List[int],
    output_filepath: str,
    shard: int
):
    """
    Creates a single TFRecord file containing one chunk of examples.

    Args:
        chunk_files: list of filepaths to images
        chunk_synsets: list of synsets corresponding to images in chunk_files
        chunk_labels: list of integer labels corresponding to images in
            chunk_files
        output_filepath: path of the output TFRecord file
        shard: index of the shard being written (used for logging)

    Returns None.
    """
    with tf.io.TFRecordWriter(output_filepath) as writer:
        for i in range(len(chunk_files)):
            image_str, height, width = _make_image(chunk_files[i], shard)
            label = chunk_labels[i]
            synset = chunk_synsets[i]
            example = _make_example(
                image_str, height, width, chunk_files[i], label, synset
            )
            writer.write(example.SerializeToString())
def make_tfrecs(
    dataset_base_dir: str = '',
    output_dir: str = '',
    file_prefix: str = '',
    synset_filepath: str = '',
    batch_size: int = 1024,
    logging_frequency: int = 1,
    logging_gap: int = 3600,
    shuffle: bool = True
):
    """
    Only public function of the module. Makes TFRecords and stores them in
    output_dir. Each TFRecord except the last one holds exactly one batch of
    data.

    Args:
        dataset_base_dir: directory containing an ImageNet-style directory
            structure (synset IDs as directory names, images inside),
            e.g. home/imagenet/train
        output_dir: directory to store TFRecords in, e.g. home/imagenet_tfrecs
        file_prefix: prefix added to the TFRecord files,
            e.g. if file_prefix = 'train' then all files look like
            `train_0000_of_<num_shards>`
        synset_filepath: path to the synset JSON file
        batch_size: batch size of the dataset. Each TFRecord, except the last
            one, will contain this many examples.
        logging_frequency: 'Writing ... shards' is logged to stdout after
            every `logging_frequency` shards are written.
        logging_gap: interval in seconds after which a '<num> shards were
            completed in the last <logging_gap> seconds' message is printed.
        shuffle: True if the dataset needs to be shuffled

    Returns None.
    """
    global _start_time, _logging_gap, _last_written_shards

    if '' in (dataset_base_dir, output_dir, file_prefix):
        raise ValueError("One or more of the required arguments is empty.")
    if not os.path.exists(dataset_base_dir):
        raise ValueError("Dataset path does not exist")
    if not os.path.exists(output_dir):
        raise ValueError("Output directory does not exist")

    if synset_filepath == '':
        synpath = _get_default_synset_path()
    else:
        synpath = synset_filepath

    images, labels, synsets = _get_files(dataset_base_dir, synpath,
                                         shuffle=shuffle)
    print('Total images: ', len(images))
    num_shards = int(math.ceil(len(images) / batch_size))

    _start_time = time.time()
    _logging_gap = logging_gap

    for shard in range(num_shards):
        if shard % logging_frequency == 0:
            print("Writing %d of %d shards" % (shard, num_shards))
        chunk_files = images[shard * batch_size: (shard + 1) * batch_size]
        chunk_synsets = synsets[shard * batch_size: (shard + 1) * batch_size]
        chunk_labels = labels[shard * batch_size: (shard + 1) * batch_size]
        output_filepath = os.path.join(
            output_dir,
            file_prefix + "_%.4d_of_%.4d.tfrecord" % (shard, num_shards),
        )
        _make_single_tfrecord(
            chunk_files, chunk_synsets, chunk_labels, output_filepath, shard
        )
    print("All shards written successfully!")