Skip to content

Instantly share code, notes, and snippets.

@CrackerHax
Last active June 3, 2019 13:48
Show Gist options
  • Save CrackerHax/06025c08bddf277696e26979b0b93e5d to your computer and use it in GitHub Desktop.
Save CrackerHax/06025c08bddf277696e26979b0b93e5d to your computer and use it in GitHub Desktop.
Create tfrecords labeled by category from directories of images
import glob
import os
import sys
from random import shuffle

import cv2
import numpy as np
import tensorflow as tf
# ---------------------------------------------------------------------------
# Settings you are expected to edit
# ---------------------------------------------------------------------------
# Name of the project directory that holds all of the images. It is looked up
# under ./images/train (change `path` below to use a different root).
name = 'mountains'
# Every image is resized to image_size x image_size pixels. Non-square
# sources get stretched, so square inputs (e.g. 256x256) work best.
image_size = 256

# Images are read from a directory tree whose folder names act as labels:
# path/name/<category>[/<subcategory>]/<image files>
# For example, with name = 'portraits':
#   images/train/portraits/male/old
#   images/train/portraits/male/young
#   images/train/portraits/female/old
#   images/train/portraits/female/young

# ---------------------------------------------------------------------------
# Internal state - leave these alone
# ---------------------------------------------------------------------------
path = 'images/train/'
addrs, labels, all_categories = [], [], []
def _int64_feature(value):
    """Wrap integer data in a ``tf.train.Feature`` holding an int64_list.

    ``value`` may be an iterable of ints, such as a list, tuple, or 1-D numpy
    array like the label vectors built by this script. It may also be a
    single integer (Python ``int`` or numpy integer). A scalar is wrapped in a
    one-element list, mirroring what ``_bytes_feature`` does for its input.
    """
    if isinstance(value, (int, np.integer)):
        # Int64List expects a sequence; passing a bare integer would fail.
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _bytes_feature(value):
    """Wrap a single bytes payload in a ``tf.train.Feature`` (bytes_list)."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def load_image(addr):
    """Read the image at ``addr`` as an RGB array of image_size x image_size.

    Returns None when OpenCV cannot read/decode the file (cv2.imread returned
    None), so callers can skip it.
    """
    bgr = cv2.imread(addr)
    if bgr is not None:
        # Resize first, then convert OpenCV's native BGR channel order to RGB.
        target = (image_size, image_size)
        resized = cv2.resize(bgr, target, interpolation=cv2.INTER_CUBIC)
        return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    return None
def createDataRecord(out_filename, addrs, labels):
    """Write (image, label) pairs to a TFRecord file.

    Args:
        out_filename: destination path of the .tfrecords file.
        addrs: sequence of image file paths (strings).
        labels: sequence of integer label vectors, parallel to ``addrs``.

    Raises:
        ValueError: if ``addrs`` and ``labels`` differ in length.

    Each image is loaded with ``load_image``. Files it cannot decode are
    reported and skipped. Every written record is a ``tf.train.Example``
    with two features:
        'image': the raw bytes of the resized RGB array. No shape or dtype
                 is stored, so readers must know image_size.
        'label': the label vector as an int64 list.
    """
    total = len(addrs)
    if total != len(labels):
        raise ValueError('addrs and labels must have the same length '
                         '({} != {})'.format(total, len(labels)))
    # `with` guarantees the writer is flushed and closed even if an image
    # fails mid-way. tf.io.TFRecordWriter is the current name of the writer;
    # tf.python_io no longer exists in TensorFlow 2.
    with tf.io.TFRecordWriter(out_filename) as writer:
        for i, (addr, label) in enumerate(zip(addrs, labels)):
            print('Train data: {}/{}'.format(i, total - 1))
            print('--- Path:' + addr + ' Labels:' + str(label))
            sys.stdout.flush()
            img = load_image(addr)
            if img is None:
                print("Error: no image")
                continue
            feature = {
                # tobytes() replaces ndarray.tostring(), deprecated since NumPy 1.19.
                'image': _bytes_feature(img.tobytes()),
                'label': _int64_feature(label),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
    sys.stdout.flush()
#--
# Index the labels.
# Category folders live directly under path/name; optional subcategory folders
# live one level below. Every distinct folder name gets one slot of the
# multi-hot label vector, in the order given by all_categories.
# Listings are sorted because os.scandir order is filesystem-dependent, and
# the slot assigned to each name must be reproducible between runs.
root = os.path.join(path, name)
image_categories = sorted(f.name for f in os.scandir(root) if f.is_dir())
# Union of subcategory names across *all* categories, not just the last one
# scanned, so categories with differing subfolders are all indexed.
subcategory_names = set()
for category in image_categories:
    category_dir = os.path.join(root, category)
    subcategory_names.update(f.name for f in os.scandir(category_dir) if f.is_dir())
image_subcategories = sorted(subcategory_names)
all_categories = image_categories + image_subcategories
print(all_categories)
shared_names = sorted(set(image_categories) & subcategory_names)
if shared_names:
    # all_categories.index() returns the first match, so such a name would
    # set the category slot for both levels.
    print('Warning: names used as both category and subcategory share a label slot:',
          shared_names)

# Get image paths and the (1 or 2) labels for each image.
# When subcategories exist, only files inside subcategory folders are used.
if image_subcategories:
    leaf_dirs = [(os.path.join(root, category, subcategory), (category, subcategory))
                 for category in image_categories
                 for subcategory in image_subcategories]
else:
    leaf_dirs = [(os.path.join(root, category), (category,))
                 for category in image_categories]
for dir_path, label_names in leaf_dirs:
    if not os.path.isdir(dir_path):
        # Not every category has to contain every subcategory.
        continue
    for entry in os.scandir(dir_path):
        if not entry.is_file():
            continue
        label = np.zeros(len(all_categories), dtype=int)
        for label_name in label_names:
            label[all_categories.index(label_name)] = 1
        addrs.append(os.path.join(dir_path, entry.name))
        labels.append(label)

if not addrs:
    sys.exit('No image files found under ' + root)

# Shuffle the (path, label) pairs together so they stay aligned.
pairs = list(zip(addrs, labels))
shuffle(pairs)
addrs, labels = zip(*pairs)

out_dir = os.path.join('datasets', name)
# makedirs also creates a missing 'datasets' parent. exist_ok replaces the
# racy exists()/mkdir() pair.
os.makedirs(out_dir, exist_ok=True)
createDataRecord(os.path.join(out_dir, name + '_train.tfrecords'), addrs, labels)
print('saved in ./datasets/'+name+'/')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment