# Source: GitHub Gist by @nahidalam (gist 38bb8d4677440d17ff020ffb0c2ea009),
# created February 3, 2022 20:58.
# Converts the images in data/train/ into sharded TFRecord files, then validates
# them by reading the records back and plotting a grid of samples.
import math
import os
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
PATH = 'data/train/'
# Sorted so shard membership and order are reproducible across runs;
# os.listdir returns entries in arbitrary order.
IMGS = sorted(os.listdir(PATH))
# Images per shard: ceil(n / 5). This writes at most 5 shard files (exactly 5
# when n is a multiple of 5). max(1, ...) keeps the shard size positive for an
# empty folder, so the shard-count division below never divides by zero.
SIZE = max(1, math.ceil(len(IMGS) / 5))
IMAGE_SIZE = [256, 256]  # [HEIGHT, WIDTH] passed to load_dataset for validation
print(f'Image samples: {len(IMGS)}')
def decode_image(image, HEIGHT, WIDTH, CHANNELS):
    """Decode JPEG bytes into a float32 image tensor with values in [-1, 1].

    Args:
        image: scalar string tensor holding JPEG-encoded bytes.
        HEIGHT, WIDTH, CHANNELS: shape the decoded pixels are reshaped to.

    Returns:
        A float32 tensor of shape [HEIGHT, WIDTH, CHANNELS]. The reshape fails
        if the decoded image does not hold exactly HEIGHT*WIDTH*CHANNELS values.
    """
    pixels = tf.image.decode_jpeg(image, channels=CHANNELS)
    # Map uint8 range [0, 255] linearly onto [-1, 1].
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])
def read_tfrecord(example, HEIGHT, WIDTH, CHANNELS):
    """Parse one serialized tf.train.Example and return its decoded image.

    The record must carry a scalar string feature named 'image' with the
    encoded JPEG bytes (the layout produced by serialize_example).
    """
    parsed = tf.io.parse_single_example(
        example,
        {'image': tf.io.FixedLenFeature([], tf.string)},
    )
    return decode_image(parsed['image'], HEIGHT, WIDTH, CHANNELS)
def load_dataset(filenames, HEIGHT, WIDTH, CHANNELS=3):
    """Build a tf.data pipeline yielding decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) to read.
        HEIGHT, WIDTH, CHANNELS: expected shape of each decoded image.

    Returns:
        A tf.data.Dataset of float32 image tensors scaled to [-1, 1].
    """
    dataset = tf.data.TFRecordDataset(filenames)
    # AUTOTUNE lets tf.data pick how many records to parse in parallel.
    return dataset.map(
        lambda example: read_tfrecord(example, HEIGHT, WIDTH, CHANNELS),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
def display_samples(ds, row, col):
    """Plot up to row*col samples from *ds* in a row x col grid.

    Each element of *ds* is indexed with [0], so it is presumably a batch whose
    first entry is an image scaled to [-1, 1] (e.g. load_dataset(...).batch(1)).
    That entry is mapped back to [0, 1] for display. Plotting stops early if
    *ds* has fewer than row*col elements.
    """
    # Keep the figure height at least 1 inch even when row is much smaller than col.
    plt.figure(figsize=(15, max(1, int(15 * row / col))))
    # zip() stops at the shorter iterable, so a short dataset no longer raises
    # StopIteration, and no extra element is pulled once the grid is full.
    for j, example_sample in zip(range(row * col), ds):
        plt.subplot(row, col, j + 1)
        plt.axis('off')
        plt.imshow(example_sample[0] * 0.5 + 0.5)
    plt.show()
def count_data_items(filenames):
    """Return the total number of records across TFRecord shard files.

    Each shard's record count is read from its file name, which must contain
    ``-<count>.`` (e.g. ``art00-143.tfrec`` holds 143 records). This is the
    naming used by the shard-writing loop in this script.

    Args:
        filenames: iterable of shard paths.

    Returns:
        The summed record count as an int (0 for no files).

    Raises:
        ValueError: if a file name has no ``-<count>.`` marker.
    """
    pattern = re.compile(r"-(\d+)\.")
    total = 0
    for filename in filenames:
        # Search only the basename so digits in directory names are ignored;
        # take the last marker in case the prefix itself contains one.
        matches = pattern.findall(os.path.basename(filename))
        if not matches:
            raise ValueError(f"Cannot infer record count from filename: {filename!r}")
        total += int(matches[-1])
    return total
# Create TF Records
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a bytes_list.

    An eager tensor is unwrapped to its NumPy value first, because BytesList
    takes raw bytes rather than tensors.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
def serialize_example(image):
    """Serialize encoded image bytes into a tf.train.Example byte string.

    The Example has one feature, 'image', which read_tfrecord parses back.
    """
    features = tf.train.Features(feature={'image': _bytes_feature(image)})
    return tf.train.Example(features=features).SerializeToString()
def _encode_image(path):
    """Read the image at *path* with OpenCV and return it as JPEG bytes.

    The image is resized to IMAGE_SIZE ([HEIGHT, WIDTH]) when its size differs,
    because decode_image reshapes every record to that fixed shape on read.

    Raises:
        ValueError: if the file cannot be read or JPEG-encoded.
    """
    img = cv2.imread(path)  # OpenCV returns pixels in BGR channel order
    if img is None:
        raise ValueError(f'Could not read image: {path}')
    height, width = IMAGE_SIZE
    if img.shape[:2] != (height, width):
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    # imencode expects BGR input, which is the order imread produced. Any
    # RGB<->BGR conversion here would swap red and blue in the stored JPEG.
    ok, buf = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 94))
    if not ok:
        raise ValueError(f'Could not JPEG-encode image: {path}')
    return buf.tobytes()


# Number of shard files: ceil(len(IMGS) / SIZE).
count = (len(IMGS) + SIZE - 1) // SIZE
for j in range(count):
    print()
    print('Writing TFRecord %i of %i...' % (j + 1, count))
    shard = IMGS[SIZE * j:SIZE * (j + 1)]
    # The record count is embedded in the file name; count_data_items reads it back.
    with tf.io.TFRecordWriter('art%.2i-%i.tfrec' % (j, len(shard))) as writer:
        for k, fname in enumerate(shard):
            writer.write(serialize_example(_encode_image(os.path.join(PATH, fname))))
            if k % 100 == 0:
                print(k, ', ', end='')
# Validate: read the shards back, report the record count and plot a sample grid.
# Sorted so the shards (and therefore the plotted samples) are read in a
# reproducible order; glob does not guarantee any ordering.
FILENAMES = sorted(tf.io.gfile.glob('art*.tfrec'))
if not FILENAMES:
    raise FileNotFoundError('No art*.tfrec files found in the working directory')
print(f'TFRecords files: {FILENAMES}')
print(f'Created image samples: {count_data_items(FILENAMES)}')
display_samples(load_dataset(FILENAMES, *IMAGE_SIZE).batch(1), 10, 10)
# End of gist.