Last active
November 1, 2022 09:31
-
-
Save carlosh93/ed0ce914285eebd365d7446f41a19912 to your computer and use it in GitHub Desktop.
Tensorflow Load Images in Dataset #python #tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://www.tensorflow.org/tutorials/load_data/images | |
import tensorflow as tf | |
AUTOTUNE = tf.data.experimental.AUTOTUNE | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import pathlib | |
# Download the TF flowers archive (~218 MB) into the Keras cache dir and untar it;
# get_file returns the local path to the extracted directory.
data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
fname='flower_photos', untar=True)
data_dir = pathlib.Path(data_dir)
# Total number of JPEGs across all class subdirectories (one subdirectory per class).
image_count = len(list(data_dir.glob('*/*.jpg')))
# Class names are the subdirectory names; LICENSE.txt sits alongside them and is excluded.
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
# Example file listing for one class. NOTE(review): unused by the rest of this script.
roses = list(data_dir.glob('roses/*'))
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# NOTE(review): computed but never used below — presumably meant for
# model.fit(steps_per_epoch=...) in a later training step.
STEPS_PER_EPOCH = np.ceil(image_count/BATCH_SIZE)
# Dataset of file-path strings, one element per image file.
# NOTE(review): list_files is believed to shuffle by default — confirm if ordering matters.
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
def get_label(file_path):
    """Derive a one-hot boolean label vector from an image's file path.

    The parent directory of the file is its class, so comparing that
    directory name against CLASS_NAMES yields a boolean vector with a
    single True at the class position.
    """
    # Second-to-last path component is the class directory name.
    class_dir = tf.strings.split(file_path, os.path.sep)[-2]
    return class_dir == CLASS_NAMES
def decode_img(img):
    """Decode a JPEG byte string into a float32 image tensor.

    Args:
        img: scalar string tensor holding the raw JPEG file contents.

    Returns:
        A (IMG_HEIGHT, IMG_WIDTH, 3) float32 tensor with values in [0, 1].
    """
    # Decode the compressed bytes into a 3-D uint8 tensor (force 3 channels).
    img = tf.image.decode_jpeg(img, channels=3)
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)
    # BUG FIX: tf.image.resize expects size=[new_height, new_width]; the
    # original passed [IMG_WIDTH, IMG_HEIGHT], which only worked because
    # both are 224. Correct order keeps behavior identical today and stays
    # correct for non-square sizes.
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
def process_path(file_path):
    """Map one image file path to an (image, label) pair.

    Reads the file's raw bytes, decodes them with decode_img, and pairs
    the result with the one-hot label from get_label.
    """
    raw_bytes = tf.io.read_file(file_path)
    image = decode_img(raw_bytes)
    return image, get_label(file_path)
# Map every file path to an (image, label) pair; AUTOTUNE lets tf.data pick
# how many images to load/decode in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    """Turn a dataset of (image, label) pairs into a training pipeline.

    Args:
        ds: the source tf.data.Dataset.
        cache: True to cache in memory, a filename string to cache on disk
            (for datasets that don't fit in memory), or falsy to skip caching.
        shuffle_buffer_size: size of the shuffle buffer.

    Returns:
        A shuffled, infinitely repeating, batched, prefetching dataset.
    """
    # A string means "cache to this file on disk"; any other truthy value
    # means "cache in memory". This dataset is small, so caching pays off.
    if isinstance(cache, str):
        ds = ds.cache(cache)
    elif cache:
        ds = ds.cache()
    # Shuffle, repeat forever, batch, and prefetch so the input pipeline
    # produces batches in the background while the model trains.
    return (ds.shuffle(buffer_size=shuffle_buffer_size)
              .repeat()
              .batch(BATCH_SIZE)
              .prefetch(buffer_size=AUTOTUNE))
def show_batch(image_batch, label_batch):
    """Display up to 25 images from a batch in a 5x5 grid with class titles.

    Args:
        image_batch: array of images, shape (batch, height, width, 3).
        label_batch: array of one-hot boolean labels, shape (batch, n_classes).
    """
    plt.figure(figsize=(10, 10))
    # ROBUSTNESS FIX: the original hard-coded range(25) and raised
    # IndexError for batches smaller than 25; cap at the batch size.
    for n in range(min(25, len(image_batch))):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        # The one-hot label selects the matching entry of CLASS_NAMES.
        plt.title(CLASS_NAMES[label_batch[n] == 1][0].title())
        plt.axis('off')
# Go Training: build the input pipeline, pull one batch, and visualize it.
train_ds = prepare_for_training(labeled_ds)
image_batch, label_batch = next(iter(train_ds))
show_batch(image_batch.numpy(), label_batch.numpy())
# Begin training
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment