from __future__ import absolute_import, division, print_function, unicode_literals
import pathlib
import os
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE = 32
IMG_HEIGHT = 120
IMG_WIDTH = 120
AUTOTUNE = tf.data.experimental.AUTOTUNE
DATA_DIR = 'data/6x6/trains'
data_dir = pathlib.Path(DATA_DIR)
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
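# Optional sanity check (not in the original gist): confirm the class
# sub-directories and image count discovered under DATA_DIR.
print("Classes found:", CLASS_NAMES)
print("Images found:", len(list(data_dir.glob('*/*'))))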
def decode_img(img):
    # Convert the compressed string to a 3D uint8 tensor.
    img = tf.image.decode_jpeg(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Resize the image to the desired size (`tf.image.resize` takes [height, width]).
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
def get_label(file_path):
    # Convert the path to a list of path components.
    parts = tf.strings.split(file_path, os.path.sep)
    # The second-to-last component is the class directory; return its index in
    # CLASS_NAMES as an integer label, which is what the
    # sparse_categorical_crossentropy loss used below expects.
    return tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
def process_path(file_path):
    label = get_label(file_path)
    # Load the raw data from the file as a string.
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img, label
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    # This is a small dataset; load it once and keep it in memory. Use
    # `.cache(filename)` to cache preprocessing work for datasets that
    # don't fit in memory.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    # Repeat forever.
    ds = ds.repeat()
    ds = ds.batch(BATCH_SIZE)
    # `prefetch` lets the dataset fetch batches in the background while the
    # model is training.
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds
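# A small visualization helper, sketched here because matplotlib is imported
# above but otherwise unused; `show_batch` is an assumption, not part of the
# original gist.
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10, 10))
    for n in range(min(25, len(image_batch))):
        ax = plt.subplot(5, 5, n + 1)
        ax.imshow(image_batch[n])
        ax.set_title(CLASS_NAMES[int(label_batch[n])])
        ax.axis('off')
    plt.show()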
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
train_ds = prepare_for_training(labeled_ds)
image_batch, label_batch = next(iter(train_ds))
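# Usage sketch for the helper above (commented out so the script stays
# non-interactive by default):
# show_batch(image_batch, label_batch)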
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    keras.layers.Dense(128, activation='relu'),
    # One output unit per discovered class, rather than a hard-coded 10.
    keras.layers.Dense(len(CLASS_NAMES), activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Note: this fits only the single batch fetched above, so it is a quick
# sanity check of the pipeline rather than real training.
model.fit(image_batch, label_batch, epochs=10)
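# A sketch of training on the full pipeline instead of one batch: the dataset
# repeats indefinitely, so `steps_per_epoch` is required to bound each epoch.
# The glob pattern assumes one image per file directly under each class
# directory.
# image_count = len(list(data_dir.glob('*/*')))
# model.fit(train_ds, steps_per_epoch=image_count // BATCH_SIZE, epochs=10)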