Skip to content

Instantly share code, notes, and snippets.

@p-geon
Created May 3, 2021 13:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save p-geon/a8a0229e120324181b7113c02cf01be2 to your computer and use it in GitHub Desktop.
Save p-geon/a8a0229e120324181b7113c02cf01be2 to your computer and use it in GitHub Desktop.
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE
import pathlib
import random
import os
import IPython.display as display
import matplotlib.pyplot as plt
class FlowerDataset:
    """Download the TensorFlow ``flower_photos`` archive and index its images.

    Constructing an instance fetches the archive through
    ``tf.keras.utils.get_file``, builds a shuffled list of image paths,
    derives an integer label for every image from the name of its parent
    directory, and finally plots the first image with its attribution.
    """

    # Lazily filled cache of the parsed LICENSE.txt, mapping an image path
    # relative to ``data_root`` (forward slashes) to its attribution text.
    # Kept at class level so instances created without ``__init__`` still
    # see a defined (empty) cache.
    _attributions = None

    def __init__(self):
        # Root directory of the extracted dataset (a ``pathlib.Path``).
        self.data_root = self.download_data()
        # Shuffled image file paths, as strings.
        self.all_image_paths = self.get_img_paths()
        # Sorted class names, and one integer label per entry of
        # ``all_image_paths`` (same order).
        self.label_names, self.all_image_labels = self.get_img_labels()

        # Preview the first sample; skip it when the dataset is empty instead
        # of failing with an IndexError.
        if self.all_image_paths:
            self.show_img(self.all_image_paths[0], self.all_image_labels[0])
        print()

    @staticmethod
    def preprocess_image(image):
        """Decode JPEG bytes into a 3-channel 192x192 image divided by 255.

        Args:
            image: Raw encoded JPEG contents, e.g. from ``tf.io.read_file``.

        Returns:
            The decoded, resized and rescaled image tensor.
        """
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [192, 192])
        image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(self, path):
        """Read the file at ``path`` and run it through ``preprocess_image``."""
        image = tf.io.read_file(path)
        return self.preprocess_image(image)

    def show_img(self, img_path, label):
        """Plot one image with its attribution as x-label and class as title.

        Args:
            img_path: Path of the image file, located under ``data_root``.
            label: Integer index into ``label_names``.
        """
        plt.imshow(self.load_and_preprocess_image(img_path))
        plt.grid(False)
        plt.xlabel(self.caption_image(img_path))
        plt.title(self.label_names[label].title())

    # ---
    # method
    # ---
    @staticmethod
    def download_data():
        """Fetch and unpack the flower archive, printing its top-level entries.

        Returns:
            ``pathlib.Path`` built from the location reported by
            ``tf.keras.utils.get_file``.
        """
        data_root_orig = tf.keras.utils.get_file(
            origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
            fname='flower_photos', untar=True)
        data_root = pathlib.Path(data_root_orig)  # /root/.keras/datasets/flower_photos
        # show dirs
        for item in data_root.iterdir():
            print(item)
        return data_root

    def get_img_paths(self):
        """Collect every ``<class>/<file>`` path under ``data_root``.

        Only entries exactly one directory below the root are matched, so
        files sitting directly in the root (such as ``LICENSE.txt``) are
        excluded. The number of images found is printed.

        Returns:
            The paths as strings, shuffled in place with ``random.shuffle``.
        """
        all_image_paths = [str(path) for path in self.data_root.glob('*/*')]
        random.shuffle(all_image_paths)
        print(len(all_image_paths))
        return all_image_paths

    def get_img_labels(self):
        """Derive class names and per-image integer labels.

        Requires ``self.all_image_paths`` to be populated already.

        Returns:
            A ``(label_names, all_image_labels)`` pair: the sorted names of
            the sub-directories of ``data_root``, and for each image path the
            index of its parent directory name within ``label_names``.

        Raises:
            KeyError: If an image's parent directory is not a direct
                sub-directory of ``data_root``.
        """
        label_names = sorted(
            item.name for item in self.data_root.iterdir() if item.is_dir())
        # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
        label_to_index = {name: index for index, name in enumerate(label_names)}
        # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                            for path in self.all_image_paths]
        print("First 10 labels indices: ", all_image_labels[:10])
        return label_names, all_image_labels

    def show_example(self, num_examples=6):
        """Display randomly chosen images, each followed by its caption.

        Args:
            num_examples: How many images to show. The default of 6 matches
                the previous behavior (two back-to-back rounds of three).
        """
        for _ in range(num_examples):
            image_path = random.choice(self.all_image_paths)
            display.display(display.Image(image_path))
            print(self.caption_image(image_path))
            print()

    def _load_attributions(self):
        """Parse ``LICENSE.txt`` once and return the cached attribution map."""
        if self._attributions is None:
            license_path = self.data_root / "LICENSE.txt"
            # Context manager so the file handle is always released.
            with license_path.open(encoding='utf-8') as license_file:
                lines = license_file.readlines()[4:]  # skip the 4 header lines
            parsed = {}
            for line in lines:
                rel_path, separator, credit = line.partition(' CC-BY')
                # Lines without the separator (e.g. blank lines) carry no
                # attribution; skip them rather than crash the whole parse.
                if separator:
                    parsed[rel_path] = credit
            self._attributions = parsed
        return self._attributions

    def caption_image(self, image_path):
        """Build the attribution caption for one image from ``LICENSE.txt``.

        Args:
            image_path: Path of an image located under ``data_root``.

        Returns:
            ``"Image (CC BY 2.0) "`` followed by the attribution text with its
            final `` - ``-separated field dropped.

        Raises:
            ValueError: If ``image_path`` is not inside ``data_root``.
            KeyError: If the license file has no entry for the image.
        """
        image_rel = pathlib.Path(image_path).relative_to(self.data_root)
        # The license file uses forward slashes; as_posix() keeps the lookup
        # key matching on platforms whose native separator differs.
        credit = self._load_attributions()[image_rel.as_posix()]
        return "Image (CC BY 2.0) " + ' - '.join(credit.split(' - ')[:-1])
# Runs whenever this module is executed or imported: the constructor downloads
# the dataset, indexes image paths and labels, and plots the first image.
FlowerDataset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment