-
-
Save rtao/50eb8c96b06f4deddec2b7888da1d062 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# ## Transfer learning | |
# This notebook is intended to try transfer learning, based on Google's MobileNet model, on a defect-checking dataset.
# In[3]: | |
#import libraries | |
from __future__ import print_function | |
import numpy as np | |
np.random.seed(2048) | |
import os | |
import glob | |
import cv2 | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from sklearn.model_selection import train_test_split | |
from tqdm import * | |
import warnings | |
warnings.filterwarnings('ignore') | |
import keras | |
from keras import applications | |
from keras import optimizers | |
from keras.models import Sequential, Model | |
from keras.layers import Dropout, Flatten, Dense | |
from keras import backend as k | |
from keras.callbacks import TensorBoard | |
from keras.applications.mobilenet import MobileNet | |
# In[4]:
# Integer class ids for the two defect labels; these are the column indices
# used when the labels are one-hot encoded.
HONEYCOMBING_ID, FORMWORK_ID = 0, 1
# In[5]:
# Data parameters.
# image_size is the (width, height) pair handed to cv2.resize as dsize.
image_size = (224, 224)
# input_shape is the Keras (height, width, channels) shape of one image,
# derived from image_size so the two cannot drift apart.
input_shape = (image_size[1], image_size[0], 3)
# Fraction of labelled rows held out for validation.
test_size = 0.25
# In[6]:
def get_im_cv2(path, size=None):
    """Read an image with OpenCV and resize it to a fixed size.

    Args:
        path: Filesystem path of the image file.
        size: Target ``(width, height)`` passed to ``cv2.resize`` as
            ``dsize``. Defaults to the module-level ``image_size``.

    Returns:
        The resized image as returned by ``cv2.resize``. ``cv2.imread``
        loads colour images in BGR channel order.

    Raises:
        IOError: If OpenCV cannot read ``path`` (missing or unreadable file).
    """
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None instead of raising.
    # Without this check the error would surface later as an opaque
    # cv2.resize assertion that does not name the offending file.
    if img is None:
        raise IOError('Could not read image: {}'.format(path))
    if size is None:
        size = image_size
    # The third positional parameter of cv2.resize is ``dst``, not the
    # interpolation flag, so the flag must be passed by keyword.
    return cv2.resize(img, size, interpolation=cv2.INTER_LINEAR)
# Maps the 'Defect' column value in the labels CSV to its integer class id.
_LABEL_TO_ID = {
    'Honeycombing': HONEYCOMBING_ID,
    'Formwork': FORMWORK_ID,
}


def _label_to_id(label):
    """Return the integer class id for a 'Defect' label string.

    Raises:
        ValueError: If ``label`` is not one of the known defect classes.
    """
    label_id = _LABEL_TO_ID.get(label)
    if label_id is None:
        # The old code mapped unknown labels to -1. keras to_categorical
        # then silently one-hot encodes -1 as the *last* class, which
        # corrupts the labels. Fail loudly instead.
        raise ValueError('Unknown defect label: {!r}'.format(label))
    return label_id


def _load_split(df, img_dir):
    """Load the image and integer label for every row of ``df``.

    Each row must provide an ``ID`` (image file stem) and a ``Defect`` label.
    Images are read from ``<img_dir>/<ID>.jpg``.
    """
    images = []
    labels = []
    # itertuples keeps each column's dtype. iterrows builds a Series per
    # row, which can upcast values.
    for row in tqdm(df.itertuples(index=False), total=len(df)):
        img_path = os.path.join(img_dir, str(row.ID) + '.jpg')
        images.append(get_im_cv2(img_path))
        labels.append(_label_to_id(row.Defect))
    return np.asarray(images), np.asarray(labels)


def load_train(path, img_dir=None, random_state=None):
    """Load the labelled images and split them into train and validation sets.

    Args:
        path: Path to the labels CSV, with columns ``ID`` and ``Defect``.
        img_dir: Directory containing ``<ID>.jpg`` for every row. Defaults to
            ``../input/train``. Both splits are read from this one directory
            because the split is made over the rows of the labels CSV.
        random_state: Forwarded to ``train_test_split``. None (the default)
            uses NumPy's global RNG, which this script seeds at import time.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as NumPy arrays. The ``x_*``
        arrays stack the resized images from ``get_im_cv2``. The ``y_*``
        arrays hold integer class ids.

    Raises:
        ValueError: If a row carries an unknown ``Defect`` label.
        IOError: If an image file cannot be read.
    """
    if img_dir is None:
        img_dir = os.path.join('..', 'input', 'train')
    labels_df = pd.read_csv(path)
    # Hold out `test_size` of the rows for validation.
    train_df, test_df = train_test_split(
        labels_df, test_size=test_size, random_state=random_state)
    x_train, y_train = _load_split(train_df, img_dir)
    x_test, y_test = _load_split(test_df, img_dir)
    return x_train, y_train, x_test, y_test
# In[7]:
# Read the labels CSV, split it, and load the images for both splits,
# then report how many samples ended up in each.
path = '../input/labels.csv'
x_train, y_train, x_test, y_test = load_train(path)
n_train, n_test = len(x_train), len(x_test)
print(n_train, 'train samples')
print(n_test, 'test samples')
# In[8]:
# Training parameters
num_classes = 2
# NOTE(review): data_augmentation is not referenced anywhere in the visible code.
data_augmentation = False
# In[9]:
# One-hot encode the integer class ids into (n_samples, num_classes) matrices.
Y_train, Y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))
# Cast the 8-bit pixel values to float32 and scale them from [0, 255] to [0, 1].
X_train, X_test = (images.astype('float32') / 255
                   for images in (x_train, x_test))
# In[10]:
# Build MobileNet *with* its own classification head (include_top=True),
# sized for our classes. weights=None means the network starts from random
# initialisation and is trained from scratch. No ImageNet weights are loaded,
# so this cell does not actually perform transfer learning.
# NOTE(review): for real transfer learning, consider weights='imagenet' with
# include_top=False plus a new head, and replace the /255 scaling with
# keras.applications.mobilenet.preprocess_input. Confirm before changing.
model = MobileNet(weights=None, include_top=True, input_shape=input_shape,
                  classes=num_classes)
# Show the summary
model.summary()
# In[13]:
# SGD with momentum and a small learning rate. `lr` is the pre-Keras-2.3
# spelling of this argument; newer releases call it `learning_rate`.
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
              metrics=["accuracy"])
# In[14]:
# Training hyper-parameters.
batch_size, epochs, verbose = 32, 50, 1
# Fit on the training split and evaluate on the held-out split after every
# epoch. The training data are reshuffled each epoch.
history = model.fit(
    X_train,
    Y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=verbose,
    shuffle=True,
    validation_data=(X_test, Y_test)
)
def _first_present_key(logs, candidates):
    """Return the first name in ``candidates`` that is a key of ``logs``.

    Raises:
        KeyError: If none of the candidate names is present.
    """
    for name in candidates:
        if name in logs:
            return name
    raise KeyError('None of {} found in training history'.format(candidates))


def _plot_history_curves(logs, train_keys, title, ylabel):
    """Plot the training and validation curves of one metric, then show them.

    ``train_keys`` lists candidate names for the training metric. The first
    one present in ``logs`` is used, and its validation counterpart is the
    same name prefixed with ``'val_'``.
    """
    train_key = _first_present_key(logs, train_keys)
    plt.plot(logs[train_key])
    plt.plot(logs['val_' + train_key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()


# summarize history for accuracy
# Keras < 2.3 logs metrics=["accuracy"] as 'acc'/'val_acc'. Keras >= 2.3 logs
# it as 'accuracy'/'val_accuracy', so the hard-coded 'acc' raised KeyError there.
_plot_history_curves(history.history, ('acc', 'accuracy'),
                     'model accuracy', 'accuracy')
# summarize history for loss
_plot_history_curves(history.history, ('loss',), 'model loss', 'loss')
# In[15]:
# Persist the trained model to an HDF5 file at a fixed absolute path.
model_path = '/notebooks/keras_mobilenet_defect.h5'
model.save(model_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment