Skip to content

Instantly share code, notes, and snippets.

@rtao

rtao/Training.py Secret

Created July 26, 2017 05:24
Show Gist options
  • Save rtao/50eb8c96b06f4deddec2b7888da1d062 to your computer and use it in GitHub Desktop.
Save rtao/50eb8c96b06f4deddec2b7888da1d062 to your computer and use it in GitHub Desktop.
# coding: utf-8
# ## Transfer learning
# This notebook tries the transfer-learning technique, using Google's MobileNet model, on a defect-checking dataset.
# In[3]:
# Import libraries
from __future__ import print_function
import numpy as np
np.random.seed(2048)
import os
import glob
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import *
import warnings
warnings.filterwarnings('ignore')
import keras
from keras import applications
from keras import optimizers
from keras.models import Sequential, Model
from keras.layers import Dropout, Flatten, Dense
from keras import backend as k
from keras.callbacks import TensorBoard
from keras.applications.mobilenet import MobileNet
# In[4]:
# Numeric IDs for the two defect classes. These become the column indices of
# the one-hot targets produced later by keras.utils.to_categorical.
HONEYCOMBING_ID = 0
FORMWORK_ID = 1
# In[5]:
# Data parameters
# Target size handed to cv2.resize, whose dsize argument is (width, height).
image_size = (224, 224)
# Keras expects (height, width, channels). Derive it from image_size so the
# two can never disagree (they would for a non-square image_size if both
# were hard-coded). 3 = colour channels per image.
input_shape = (image_size[1], image_size[0], 3)
# Fraction of the labelled rows held out by train_test_split for validation.
test_size = 0.25
# In[6]:
def get_im_cv2(path):
    """Load an image from disk and resize it to ``image_size``.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The array returned by ``cv2.resize``, resized to ``image_size``
        (OpenCV interprets ``image_size`` as ``(width, height)``).

    Raises:
        IOError: If OpenCV cannot read the file. ``cv2.imread`` reports
            failure by returning ``None`` rather than raising, which would
            otherwise surface as an obscure error inside ``cv2.resize``.
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError('Could not read image: {}'.format(path))
    # The third positional parameter of cv2.resize is ``dst``, not the
    # interpolation flag, so the flag must be passed by keyword.
    return cv2.resize(img, image_size, interpolation=cv2.INTER_LINEAR)
def _label_to_id(label):
    """Return the numeric class ID for a ``Defect`` label string.

    Raises:
        ValueError: If ``label`` is not a known defect class.
    """
    label_ids = {'Honeycombing': HONEYCOMBING_ID, 'Formwork': FORMWORK_ID}
    try:
        return label_ids[label]
    except KeyError:
        # Reject instead of encoding as -1: to_categorical does not validate
        # negative IDs (Keras 2's NumPy implementation indexes the last
        # class), so -1 would silently mislabel the sample.
        raise ValueError('Unknown Defect label: {!r}'.format(label))


def _load_split(df, img_dir):
    """Read every image listed in ``df``; return ``(images, label_ids)`` arrays."""
    images = []
    label_ids = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        img_path = os.path.join(img_dir, str(row.ID) + '.jpg')
        images.append(get_im_cv2(img_path))
        label_ids.append(_label_to_id(row.Defect))
    return np.asarray(images), np.asarray(label_ids)


def load_train(path, img_dir=None):
    """Load the labelled images and split them into training and validation sets.

    Args:
        path: Path to the label CSV. Each row needs an ``ID`` column (the
            image file name without its ``.jpg`` extension) and a ``Defect``
            column holding ``'Honeycombing'`` or ``'Formwork'``.
        img_dir: Directory containing the ``<ID>.jpg`` files. Defaults to
            ``../input/train`` relative to the working directory. Images for
            both splits are read from this directory.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as NumPy arrays: the images
        loaded by ``get_im_cv2`` and their integer class IDs
        (``HONEYCOMBING_ID`` / ``FORMWORK_ID``).

    Raises:
        ValueError: If a ``Defect`` value is not one of the known labels.
    """
    if img_dir is None:
        img_dir = os.path.join('..', 'input', 'train')
    labels_df = pd.read_csv(path)
    # Hold out ``test_size`` of the rows. No random_state is passed, so the
    # split follows NumPy's global RNG (seeded via np.random.seed above).
    train_df, test_df = train_test_split(labels_df, test_size=test_size)
    x_train, y_train = _load_split(train_df, img_dir)
    x_test, y_test = _load_split(test_df, img_dir)
    return x_train, y_train, x_test, y_test
# In[7]:
# Label CSV listing each image's ID and its Defect class.
path = '../input/labels.csv'
(x_train, y_train,
 x_test, y_test) = load_train(path)
# Report how many samples ended up in each split.
print('{} train samples'.format(len(x_train)))
print('{} test samples'.format(len(x_test)))
# In[8]:
# Training parameters
num_classes = 2
# NOTE(review): data_augmentation is not read anywhere in the visible code.
data_augmentation = False
# In[9]:
# One-hot encode the integer class IDs to match categorical cross-entropy.
Y_train, Y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))
# Cast to float32 and divide by 255 (assumes 8-bit pixel values, as the
# divisor implies) so inputs fall in [0, 1].
X_train, X_test = (images.astype('float32') / 255
                   for images in (x_train, x_test))
# In[10]:
# Build MobileNet with its own classification head (include_top=True) sized
# to num_classes, rather than repeating the class count as a literal.
# NOTE(review): weights=None means the network is randomly initialised and
# trained from scratch, so no pretrained features are transferred despite
# the notebook's "transfer learning" title. Keras only accepts
# weights='imagenet' with include_top=True when classes=1000, so reusing
# ImageNet weights here would need include_top=False plus a new head.
model = MobileNet(weights=None, include_top=True, input_shape=input_shape,
                  classes=num_classes)
# Show the summary
model.summary()
# In[13]:
# SGD with momentum and a small learning rate. Categorical cross-entropy
# matches the one-hot targets built with to_categorical.
# NOTE(review): newer Keras releases renamed SGD's `lr` to `learning_rate`.
model.compile(
    optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
    loss='categorical_crossentropy',
    metrics=["accuracy"],
)
# In[14]:
def _plot_history(hist, train_key, val_key, title, ylabel):
    """Plot one train/validation metric pair from a Keras History on a new figure."""
    # A fresh figure per plot: with non-interactive backends plt.show() does
    # not reset the current figure, so the loss curves would otherwise be
    # drawn on top of the accuracy plot.
    plt.figure()
    plt.plot(hist.history[train_key])
    plt.plot(hist.history[val_key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()


batch_size = 32
epochs = 50
verbose = 1
# Train the new model; the held-out split is used as validation data.
history = model.fit(X_train, Y_train,
                    epochs=epochs, batch_size=batch_size,
                    validation_data=(X_test, Y_test),
                    verbose=verbose, shuffle=True)
# Keras < 2.3 records the 'accuracy' metric under 'acc'; later releases use
# 'accuracy'. Use whichever key this Keras version produced.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'
# summarize history for accuracy
_plot_history(history, acc_key, 'val_' + acc_key, 'model accuracy', 'accuracy')
# summarize history for loss
_plot_history(history, 'loss', 'val_loss', 'model loss', 'loss')
# In[15]:
# Persist the trained model to an .h5 (HDF5) file.
# NOTE(review): '/notebooks' is an absolute, environment-specific location.
model.save(filepath='/notebooks/keras_mobilenet_defect.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment