@viralbthakar · Created October 26, 2020
Image Classification Example for Kaggle Dogs vs Cats Challenge
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Convolution2D, MaxPooling2D, Activation, Dense, Flatten
from tensorflow.keras.models import Model
#Step - 1 - Set the Parameters
DATA_DIR = "./data"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
CLASS_LIST = ["cat", "dog"]
EXTENSION = ".jpg"
SPLIT_INDEX = 0.8
BATCH_SIZE = 32
INPUT_SHAPE = [256, 256, 3]
LEARNING_RATE = 0.0001
EPOCHS = 20
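#Note: the constants above assume Kaggle's "Dogs vs. Cats" train archive has
#been extracted into ./data/train, with files named "<class>.<id>.jpg"
#(e.g. cat.0.jpg, dog.0.jpg); SPLIT_INDEX is the fraction of images per class
#that goes into the training split.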
#Step - 2 - Split Data
#Function to extract the list of images for a particular class
def get_per_class_image_list(image_list, class_name):
    class_name_image_list = [os.path.join(TRAIN_DIR, image_file) for image_file in image_list if image_file.split(".")[0] == class_name]
    print("For class {} found {} images".format(class_name, len(class_name_image_list)))
    return class_name_image_list
#Function to split the training and validation dataset
def split_data(image_list, class_list, split_index):
    train_data_dict = {"images": [], "labels": []}
    val_data_dict = {"images": [], "labels": []}
    for i, class_name in enumerate(class_list):
        class_image_list = get_per_class_image_list(image_list, class_name)
        train_image_list = class_image_list[:int(len(class_image_list)*split_index)]
        val_image_list = class_image_list[int(len(class_image_list)*split_index):]
        train_label_list = [i] * len(train_image_list)
        val_label_list = [i] * len(val_image_list)
        train_data_dict["images"].extend(train_image_list)
        train_data_dict["labels"].extend(train_label_list)
        val_data_dict["images"].extend(val_image_list)
        val_data_dict["labels"].extend(val_label_list)
    return train_data_dict, val_data_dict
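#Usage sketch (hypothetical, not run here): with CLASS_LIST = ["cat", "dog"],
#labels are the class indices from enumerate, i.e. 0 for every cat image and
#1 for every dog image; each dict maps "images" to file paths and "labels" to
#the matching list of integer labels.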
#Step - 3 - Build Data Pipeline
#Function to read, resize and scale the image
def get_img_file(img_path, input_shape):
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [input_shape[0], input_shape[1]], antialias=True)
    image = tf.cast(image, tf.float32) / 255.0
    return image
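#get_img_file returns a float32 tensor of shape (input_shape[0], input_shape[1], 3)
#with pixel values scaled to [0, 1], e.g. (256, 256, 3) for the INPUT_SHAPE above.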
#Parser function which returns the image and label
def parse_function(ip_dict, input_shape):
    label = ip_dict["labels"]
    image = get_img_file(img_path=ip_dict["images"], input_shape=input_shape)
    return image, label
#Main Data pipeline
def get_data_pipeline(data_dict, batch_size, input_shape):
    total_images = len(data_dict["images"])
    with tf.device("/cpu:0"):
        dataset = tf.data.Dataset.from_tensor_slices(data_dict)
        dataset = dataset.shuffle(total_images)
        dataset = dataset.map(lambda ip_dict: parse_function(ip_dict, input_shape), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=1)
    return dataset
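#Each element of the returned dataset is an (images, labels) batch with shapes
#(batch_size, 256, 256, 3) and (batch_size,); the final batch may be smaller
#since batch() is called without drop_remainder.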
#Step - 4 - Build the Model Architecture
def get_model_arch(input_shape, last_layer_activation='sigmoid'):
    input_img = Input(input_shape, name='input')
    x = Convolution2D(64, (3, 3), activation='relu', padding='same', name='fe0_conv1')(input_img)
    x = Convolution2D(64, (3, 3), activation='relu', padding='same', name='fe0_conv2')(x)
    x = MaxPooling2D((2, 2), padding='same', name='fe0_mp')(x)
    x = Convolution2D(128, (3, 3), activation='relu', padding='same', name='fe1_conv1')(x)
    x = Convolution2D(128, (3, 3), activation='relu', padding='same', name='fe1_conv2')(x)
    x = MaxPooling2D((2, 2), padding='same', name='fe1_mp')(x)
    x = Convolution2D(256, (3, 3), activation='relu', padding='same', name='fe2_conv1')(x)
    x = Convolution2D(256, (3, 3), activation='relu', padding='same', name='fe2_conv2')(x)
    x = MaxPooling2D((2, 2), padding='same', name='fe2_mp')(x)
    x = Flatten(name='feature')(x)
    x = Dense(100, activation='relu', name='fc0')(x)
    x = Dense(10, activation='relu', name='fc1')(x)
    logits = Dense(1, name='logits')(x)
    probabilities = Activation(last_layer_activation)(logits)
    model_arch = Model(inputs=input_img, outputs=probabilities)
    return model_arch
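#With CLASS_LIST = ["cat", "dog"], the single sigmoid unit outputs the
#probability of class index 1 ("dog"); outputs below 0.5 are read as "cat".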
#Splitting the Data
all_image_list = [img_file_name for img_file_name in os.listdir(TRAIN_DIR) if os.path.splitext(img_file_name)[-1]==EXTENSION]
train_data_dict, val_data_dict = split_data(image_list=all_image_list, class_list=CLASS_LIST, split_index=SPLIT_INDEX)
print("-"*15, "DATA SUMMARY", "-"*15)
print("Total Files in {} Directory : {}".format(TRAIN_DIR, len(all_image_list)))
print("Total Train Images : {} and Labels : {}".format(len(train_data_dict["images"]), len(train_data_dict["labels"])))
print("Total Validation Images : {} and Labels : {}".format(len(val_data_dict["images"]), len(val_data_dict["labels"])))
#Building The Data Pipelines
train_data_pipeline = get_data_pipeline(data_dict=train_data_dict, batch_size=BATCH_SIZE, input_shape=INPUT_SHAPE)
val_data_pipeline = get_data_pipeline(data_dict=val_data_dict, batch_size=BATCH_SIZE, input_shape=INPUT_SHAPE)
#image_batch, label_batch = next(iter(train_data_pipeline))
#print("Image Batch Shape : {}".format(image_batch.numpy().shape))
#print("Label Batch : {}".format(label_batch.numpy()))
#Building the Model
loss = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
metric = tf.keras.metrics.BinaryAccuracy(name='binaryAcc')
model = get_model_arch(input_shape=INPUT_SHAPE)
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()
#Training The Model
#Note: shuffling is handled inside the tf.data pipeline, so model.fit's
#shuffle argument is not needed (Keras ignores it for dataset inputs).
history = model.fit(train_data_pipeline, epochs=EPOCHS,
                    validation_data=val_data_pipeline, verbose=1)
model.save("Weights.h5")
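#A minimal inference sketch (hypothetical test image path, not run here):
#loaded_model = tf.keras.models.load_model("Weights.h5")
#test_image = get_img_file("./data/test/1.jpg", INPUT_SHAPE)
#prob_dog = loaded_model.predict(tf.expand_dims(test_image, axis=0))[0][0]
#print("P(dog) = {:.3f}".format(prob_dog))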
#Plot the training metrics (loss and accuracy curves).
def plot_metric_curve(history, metric, title):
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title(title)
    plt.ylabel(metric)
    plt.xlabel('Epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
plot_metric_curve(history, metric='loss', title="Loss Comparison")
plot_metric_curve(history, metric='binaryAcc', title="Binary Accuracy Comparison")