Created
May 3, 2021 13:40
-
-
Save p-geon/a8a0229e120324181b7113c02cf01be2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pathlib
import random
from typing import Dict, List, Optional, Tuple

import IPython.display as display
import matplotlib.pyplot as plt
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
class FlowerDataset:
    """Download the TensorFlow ``flower_photos`` archive and index its images.

    Constructing an instance downloads (or reuses the cached copy of) the
    archive, collects every ``<label>/<file>`` path in shuffled order, maps
    each path to an integer label, and plots the first image.

    Attributes set by ``__init__``:
        data_root: ``pathlib.Path`` of the extracted dataset directory.
        all_image_paths: shuffled list of image paths as strings.
        label_names: sorted names of the sub-directories of ``data_root``.
        all_image_labels: label index for each entry of ``all_image_paths``.
    """

    # Lazily-filled cache of the LICENSE.txt attributions, keyed by the
    # image path relative to ``data_root`` (forward-slash separated).
    _attributions: Optional[Dict[str, str]] = None

    def __init__(self):
        self.data_root = self.download_data()
        self.all_image_paths = self.get_img_paths()
        self.label_names, self.all_image_labels = self.get_img_labels()
        # self.show_example()

        # Preview the first image; skipped when the dataset is empty so that
        # construction does not fail with an IndexError.
        if self.all_image_paths:
            image_path = self.all_image_paths[0]
            label = self.all_image_labels[0]
            self.show_img(image_path, label)
        print()

    @staticmethod
    def preprocess_image(image):
        """Decode JPEG data to 3 channels, resize to 192x192, scale by 1/255."""
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [192, 192])
        image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(self, path):
        """Read the file at ``path`` and run it through ``preprocess_image``."""
        image = tf.io.read_file(path)
        return self.preprocess_image(image)

    def show_img(self, img_path, label):
        """Plot one image with its attribution as x-label and its label as title.

        Args:
            img_path: path of the image file to draw.
            label: index into ``self.label_names`` used for the title.
        """
        plt.imshow(self.load_and_preprocess_image(img_path))
        plt.grid(False)
        plt.xlabel(self.caption_image(img_path))
        plt.title(self.label_names[label].title())

    # ---
    # method
    # ---
    @staticmethod
    def download_data():
        """Fetch and untar the flower photos archive; return its directory.

        Prints every entry found directly inside the returned directory.

        Returns:
            ``pathlib.Path`` built from the value returned by
            ``tf.keras.utils.get_file`` (e.g. /root/.keras/datasets/flower_photos).
        """
        data_root_orig = tf.keras.utils.get_file(
            origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
            fname='flower_photos', untar=True)
        data_root = pathlib.Path(data_root_orig)
        # show dirs
        for item in data_root.iterdir():
            print(item)
        return data_root

    def get_img_paths(self) -> List[str]:
        """Return every ``<dir>/<file>`` path under ``data_root``, shuffled.

        Also prints the number of paths found.
        """
        all_image_paths = [str(path) for path in self.data_root.glob('*/*')]
        random.shuffle(all_image_paths)
        print(len(all_image_paths))
        return all_image_paths

    def get_img_labels(self) -> Tuple[List[str], List[int]]:
        """Derive label names and the per-image label indices.

        Label names are the sorted names of the directories directly under
        ``data_root``; each image is labelled by the index of its parent
        directory's name. Reads ``self.all_image_paths``.

        Returns:
            ``(label_names, all_image_labels)``.
        """
        label_names = sorted(
            item.name for item in self.data_root.iterdir() if item.is_dir())
        # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        label_to_index = {name: index for index, name in enumerate(label_names)}
        all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                            for path in self.all_image_paths]
        print("First 10 labels indices: ", all_image_labels[:10])
        return label_names, all_image_labels

    def show_example(self, count: int = 6):
        """Display ``count`` randomly chosen images, each followed by its caption.

        The default of 6 matches the two back-to-back three-image loops this
        method previously contained.
        """
        for _ in range(count):
            image_path = random.choice(self.all_image_paths)
            display.display(display.Image(image_path))
            print(self.caption_image(image_path))
            print()

    def _load_attributions(self) -> Dict[str, str]:
        """Parse ``LICENSE.txt`` into a ``{relative path: credit text}`` dict."""
        license_path = self.data_root / "LICENSE.txt"
        with license_path.open(encoding='utf-8') as license_file:
            # The first four lines are skipped; entries start on line five.
            lines = license_file.readlines()[4:]
        attributions = {}
        for line in lines:
            rel_path, separator, credit = line.partition(' CC-BY')
            if separator:  # ignore lines that carry no attribution entry
                attributions[rel_path] = credit
        return attributions

    def caption_image(self, image_path) -> str:
        """Return the attribution caption for ``image_path``.

        The caption is ``"Image (CC BY 2.0) "`` followed by the credit text
        from ``LICENSE.txt`` with its last `` - ``-separated field removed.

        Raises:
            KeyError: if the image has no entry in ``LICENSE.txt``.
            ValueError: if ``image_path`` is not located under ``data_root``.
        """
        if self._attributions is None:
            # Parse the licence file once and reuse it for later captions.
            self._attributions = self._load_attributions()
        image_rel = pathlib.Path(image_path).relative_to(self.data_root)
        # LICENSE.txt keys use forward slashes, so look up the POSIX form.
        credit = self._attributions[image_rel.as_posix()]
        return "Image (CC BY 2.0) " + ' - '.join(credit.split(' - ')[:-1])
# Build the dataset: downloads the archive, indexes it and plots one sample.
FlowerDataset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment