# Created June 21, 2021 05:27
"""Script to make and save TFRecords from ImageNet files"""
import tensorflow as tf
import os
import random
import json
import math
import time
from .image_utils import *
from typing import Tuple, List
# Progress-logging state. make_tfrecs sets _start_time and _logging_gap before
# writing; _make_image reads all three and updates _start_time and
# _last_written_shards whenever it prints a progress message.
_last_written_shards = 0  # shard index recorded at the last progress message
_start_time = time.time()  # wall-clock time of the last progress message
_logging_gap = 0  # minimum number of seconds between progress messages
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte.

    Args:
        value: a ``bytes`` object or an eager scalar string tensor.

    Returns:
        A ``tf.train.Feature`` wrapping a single-element ``BytesList``.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor, so extract the
        # raw Python bytes first.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Wraps one float / double in a ``tf.train.Feature`` (FloatList).

    Args:
        value: the scalar to store.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wraps one bool / enum / int / uint in a ``tf.train.Feature`` (Int64List).

    Args:
        value: the integer-like scalar to store.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
# Call graph: make_tfrecs -> (once per shard) _make_single_tfrecord
#   -> (once per image) _make_image, _make_example
def _get_synset_labels(filepath: str) -> dict:
Gets synsets from json file in a dict
filepath: json file path
Dict having the following structure:
{str id : (int synset_ID, str label_name )}
with open(filepath, "r") as f:
raw_labels_dict = json.load(f)
labels_dict = dict()
for i in raw_labels_dict:
labels_dict[raw_labels_dict[i]["id"]] = (
return labels_dict
def _get_default_synset_path() -> str:
self_path = __file__
path_segments = self_path.split('/')
regnety_path = '/'.join(path_segments[:-2])
return os.path.join(regnety_path, 'config', 'imagenet_synset_to_human.json')
def _get_files(
    data_dir: str,
    synset_filepath: str,
    shuffle: bool = True,
) -> Tuple[List[str], List[int], List[str]]:
    """Returns lists of all files, their integer labels and their synsets.

    Args:
        data_dir: directory containing an ImageNet-style directory structure
            (synset IDs as directory names, images inside).
        synset_filepath: path to the synsets JSON file.
        shuffle: True if the data needs to be shuffled.

    Returns:
        all_images: paths to all image files.
        all_labels: integer labels corresponding to images in all_images.
        all_synsets: synset strings corresponding to images in all_images.

    Raises:
        KeyError: if a directory's synset is missing from the JSON file.
    """
    # Glob order is not guaranteed. Sorting makes shuffle=False reproducible,
    # and makes shuffle=True reproducible under a fixed `random` seed.
    all_images = sorted(tf.io.gfile.glob(os.path.join(data_dir, "*", "*.JPEG")))
    if shuffle:
        random.shuffle(all_images)

    # Each image's parent directory name has its first character moved to a
    # "-n" suffix (presumably "n01440764" -> "01440764-n"). This should match
    # the "id" keys returned by _get_synset_labels. It is applied whether or
    # not the data is shuffled.
    all_synsets = [
        os.path.basename(os.path.dirname(f))[1:] + "-n" for f in all_images
    ]

    labels_dict = _get_synset_labels(synset_filepath)
    all_labels_int = [labels_dict[synset][0] for synset in all_synsets]
    return all_images, all_labels_int, all_synsets
def _maybe_log_progress(shard: int) -> None:
    """Prints shard throughput if at least `_logging_gap` seconds have passed."""
    global _start_time, _last_written_shards
    now = time.time()
    elapsed = now - _start_time
    if elapsed >= _logging_gap:
        # Report the measured interval: the check only runs when an image is
        # processed, so the real gap can exceed _logging_gap.
        print('%d shards were completed in the last %d seconds.' %
              (shard - _last_written_shards, elapsed))
        _last_written_shards = shard
        _start_time = now


def _make_image(filepath: str, shard: int) -> Tuple[tf.Tensor, int, int]:
    """Reads an image and returns it JPEG-encoded in RGB, with its size.

    PNG and CMYK inputs are converted to JPEG / RGB using the image_utils
    helpers. Grayscale images are expanded to three channels.

    Args:
        filepath: path to the .JPEG image.
        shard: index of the shard currently being written; used only for
            progress logging.

    Returns:
        A tuple (image_str, height, width). image_str holds the JPEG RGB
        encoded image; height and width are in pixels.

    Raises:
        ValueError: if the decoded image is not rank 3.
    """
    _maybe_log_progress(shard)

    image_str = tf.io.read_file(filepath)

    # Files are collected by a *.JPEG glob, yet some may really be PNG or
    # CMYK. Normalise them before decoding.
    if is_png(filepath):
        image_str = png_to_jpeg(image_str)
    if is_cmyk(filepath):
        image_str = cmyk_to_rgb(image_str)

    image_tensor = tf.io.decode_jpeg(image_str)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(image_tensor.shape) != 3:
        raise ValueError(
            "Expected a rank-3 image tensor for %s, got shape %s"
            % (filepath, image_tensor.shape)
        )
    height, width = image_tensor.shape[0], image_tensor.shape[1]

    if not is_rgb(image_tensor):
        image_tensor = tf.image.grayscale_to_rgb(image_tensor)
        # Re-encode only when the pixels changed. This avoids a needless
        # lossy JPEG round-trip for images that are already RGB.
        image_str = tf.io.encode_jpeg(image_tensor, format="rgb")

    return image_str, height, width
def _make_example(
    image_str: bytes,
    height: int,
    width: int,
    filepath: str,
    label: int,
    synset: str) -> tf.train.Example:
    """Makes a single example from arguments.

    Args:
        image_str: bytes string of the image in JPEG RGB format.
        height: height of the image in pixels.
        width: width of the image in pixels.
        filepath: path to the image; only its basename is stored.
        label: integer denoting the label.
        synset: synset string corresponding to the image.

    Returns:
        A tf.train.Example with the features "image", "height", "width",
        "filename", "label" and "synset".
    """
    # Encode str -> bytes directly. The old `bytes(s).encode(...)` attempt
    # always raised TypeError on Python 3. Its broad `except TypeError`
    # fallback could also hide unrelated errors.
    feature = {
        "image": _bytes_feature(image_str),
        "height": _int64_feature(height),
        "width": _int64_feature(width),
        "filename": _bytes_feature(os.path.basename(filepath).encode("utf8")),
        "label": _int64_feature(label),
        "synset": _bytes_feature(synset.encode("utf8")),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def _make_single_tfrecord(
    chunk_files: List[str],
    chunk_synsets: List[str],
    chunk_labels: List[int],
    output_filepath: str,
    shard: int,
) -> None:
    """Creates a single TFRecord file containing one chunk of examples.

    Args:
        chunk_files: list of filepaths to images.
        chunk_synsets: list of synsets corresponding to images in chunk_files.
        chunk_labels: list of integer labels corresponding to images in
            chunk_files.
        output_filepath: output TFRecord file.
        shard: index of this shard; forwarded to _make_image for progress
            logging.

    Raises:
        ValueError: if the three chunk lists differ in length.

    Returns:
        None
    """
    if not len(chunk_files) == len(chunk_synsets) == len(chunk_labels):
        raise ValueError(
            "chunk_files, chunk_synsets and chunk_labels must have equal lengths"
        )

    with tf.io.TFRecordWriter(output_filepath) as writer:
        for filepath, synset, label in zip(chunk_files, chunk_synsets, chunk_labels):
            image_str, height, width = _make_image(filepath, shard)
            example = _make_example(
                image_str, height, width, filepath, label, synset
            )
            writer.write(example.SerializeToString())
def make_tfrecs(
    dataset_base_dir: str = '',
    output_dir: str = '',
    file_prefix: str = '',
    synset_filepath: str = '',
    batch_size: int = 1024,
    logging_frequency: int = 1,
    logging_gap: int = 3600,
    shuffle: bool = True,
) -> None:
    """Only public function of the module. Makes TFRecords and stores them in
    output_dir. Each TFRecord except the last one has exactly one batch of data.

    Args:
        dataset_base_dir: directory containing an ImageNet-style directory
            structure (synset IDs as directory names, images inside),
            e.g. home/imagenet/train.
        output_dir: directory to store TFRecords, e.g. home/imagenet_tfrecs.
        file_prefix: prefix added to the TFRecord file names, e.g. if
            file_prefix = 'train' then all files look like
            `train_0000_of_<num_shards>.tfrecord`.
        synset_filepath: path to the synsets JSON file. If empty, the default
            path from _get_default_synset_path() is used.
        batch_size: batch size of the dataset. Each TFRecord, except the last
            one, will contain this many examples.
        logging_frequency: 'Writing shard ..' is logged to stdout every this
            many shards.
        logging_gap: interval in seconds after which a '<num> shards were
            completed ...' message is printed.
        shuffle: True if the dataset needs to be shuffled.

    Raises:
        ValueError: on empty/None required paths, non-positive batch_size or
            logging_frequency, missing directories, or if no images are found.

    Returns:
        None
    """
    global _start_time, _logging_gap, _last_written_shards

    # `not all(...)` rejects both '' and None. None would otherwise crash
    # later inside os.path.exists with a TypeError.
    if not all((dataset_base_dir, output_dir, file_prefix)):
        raise ValueError(
            "One or more of dataset_base_dir, output_dir and file_prefix "
            "is empty or None."
        )
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if logging_frequency <= 0:
        raise ValueError("logging_frequency must be a positive integer.")
    if not os.path.exists(dataset_base_dir):
        raise ValueError("Dataset path does not exist")
    if not os.path.exists(output_dir):
        raise ValueError("Output directory does not exist")

    # Use an equality/truthiness test; `is ''` compares object identity.
    synpath = synset_filepath if synset_filepath else _get_default_synset_path()

    images, labels, synsets = _get_files(dataset_base_dir, synpath,
                                         shuffle=shuffle)
    print('Total images: ', len(images))
    if not images:
        raise ValueError("No .JPEG images found under %s" % dataset_base_dir)

    num_shards = int(math.ceil(len(images) / batch_size))

    # Reset all progress-logging state so that a second call (e.g. train then
    # validation) does not report counts relative to the previous run.
    _start_time = time.time()
    _logging_gap = logging_gap
    _last_written_shards = 0

    for shard in range(num_shards):
        if shard % logging_frequency == 0:
            print("Writing %d of %d shards" % (shard, num_shards))

        start, end = shard * batch_size, (shard + 1) * batch_size
        output_filepath = os.path.join(
            output_dir,
            file_prefix + "_%.4d_of_%.4d.tfrecord" % (shard, num_shards),
        )
        _make_single_tfrecord(
            images[start:end],
            synsets[start:end],
            labels[start:end],
            output_filepath,
            shard,
        )

    print("All shards written successfully!")
