Last active: June 17, 2019 04:53
Save Aries0d0f/03af639588dead2c4eaaa898195b7081 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: UTF-8 -*- | |
# Import basic modules | |
import numpy as np | |
import cv2 | |
import matplotlib.pyplot as plt | |
# Load Keras utilies & datasets | |
from keras.utils import np_utils | |
from keras.datasets import mnist | |
from keras.models import Sequential | |
from keras.layers import Dense | |
from keras.layers import Dropout | |
###### CONFIGURATION START ######
# All numeric settings must be positive integers (> 0).
NUMBER_OF_DATA_TO_TRAIN = 25  # int > 0; how many training samples are previewed before training
EPOCHS = 20  # int > 0; number of epochs passed to model.fit
HIDDEN_UNITS = 1000  # int > 0; units in each hidden Dense layer
SKETCH_NAMESPACE = 'sketch'  # name of the OpenCV drawing window
###### CONFIGURATION END ######
# Global state shared with the mouse callback draw_sketch
drawing = False  # True while the left mouse button is held down
iX, iY = -1, -1  # position of the most recent left-button press (-1 = none yet)
# Fetch MNIST: the training pairs D(X) -> Y and the held-out test pairs d(X) -> Y
(X_train_images, Y_train_labels), (X_test_images, Y_test_labels) = mnist.load_data()
# Func: Show a grid of images with their labels (and predictions, if given)
def print_images_with_label(images, labels, prediction, start_index, number_of_data = 10, cmap_arg = 'binary', columns = 5):
    """Plot ``number_of_data`` consecutive samples starting at ``start_index``.

    ``images``, ``labels`` and ``prediction`` are indexed in parallel with the
    same absolute index, so every cell shows one sample's image, label and
    (optionally) prediction.

    Args:
        images: sequence of 2-D images drawn with ``imshow``.
        labels: ground-truth labels, indexed like ``images``.
        prediction: predicted labels indexed like ``images``; pass an empty
            sequence (or None) to show labels only.
        start_index: index of the first sample to show.
        number_of_data: how many samples to show; nothing is drawn if <= 0.
        cmap_arg: matplotlib colormap name used for the images.
        columns: images per grid row (default 5).
    """
    if number_of_data <= 0:
        return
    # Integer ceiling division: a partial last row still gets grid slots, and
    # plt.subplot requires integer row/column counts (n / 5 is a float in Py3).
    rows = (number_of_data + columns - 1) // columns
    fig = plt.gcf()
    fig.set_size_inches(8, 2 * rows)
    for offset in range(number_of_data):
        sample = start_index + offset
        sub_graph = plt.subplot(rows, columns, offset + 1)
        # Image, label and prediction must all come from the same sample.
        sub_graph.imshow(images[sample], cmap = cmap_arg)
        title = "[label]: " + str(labels[sample]) + "\n"
        if prediction is not None and len(prediction) > 0:
            title += "[predict]: " + str(prediction[sample])
        sub_graph.set_title(title, fontsize = 10)
        sub_graph.set_xticks([])
        sub_graph.set_yticks([])
    plt.show()
# Func: Plot a training metric against its validation counterpart per epoch
def show_training_history(training_history, train, validation, title):
    """Plot ``history[train]`` and ``history[validation]`` on one chart.

    Args:
        training_history: the object returned by ``model.fit``; its
            ``history`` dict holds one value per epoch for each metric key.
        train: key of the training metric (e.g. 'loss').
        validation: key of the validation metric (e.g. 'val_loss').
        title: chart title.
    """
    history = training_history.history
    plt.plot(history[train])
    # Read the validation curve from the argument as well, not from the
    # module-level ``train_history`` variable.
    plt.plot(history[validation])
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(train)
    plt.legend(
        ['train', 'validation'],
        loc = 'upper left'
    )
    plt.show()
# Func: Scale data down by a constant level
def normalization(data, level):
    """Return ``data`` divided by ``level``.

    With ``level=255`` this maps 8-bit pixel values onto the range [0, 1].
    Works for plain numbers and for numpy arrays (element-wise).
    """
    scaled = data / level
    return scaled
# Func: OpenCV mouse callback that paints white strokes onto the global sketch
def draw_sketch(event, X, Y, flags, param):
    """Paint on the module-level ``sketch`` canvas while the left button is held.

    Registered with ``cv2.setMouseCallback``; ``flags`` and ``param`` belong to
    the OpenCV callback signature and are not used here.

    - Left button down: start a stroke and remember the press point in (iX, iY).
    - Left button up: end the stroke and stamp one last circle at (X, Y).
    - Mouse move while a stroke is active: stamp a filled white circle of
      radius 5 at (X, Y).
    """
    global iX, iY, drawing
    white = (255, 255, 255)
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        iX, iY = X, Y
    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False
        cv2.circle(sketch, (X, Y), 5, white, -1)
    elif event == cv2.EVENT_MOUSEMOVE and drawing:
        cv2.circle(sketch, (X, Y), 5, white, -1)
# Preview the first NUMBER_OF_DATA_TO_TRAIN training images with their labels
print_images_with_label(X_train_images, Y_train_labels, [], 0, NUMBER_OF_DATA_TO_TRAIN)
# Flatten each image into a float32 row vector: (samples, height * width)
input_size = X_train_images.shape[1] * X_train_images.shape[2]
X_train_list = X_train_images.reshape(X_train_images.shape[0], input_size).astype('float32')
X_test_list = X_test_images.reshape(X_test_images.shape[0], input_size).astype('float32')
# Divide pixel values by 255 before feeding the network
X_train_list_normalized = normalization(X_train_list, 255)
X_test_list_normalized = normalization(X_test_list, 255)
# Print normalized datasets
print('[train data D(X)]:\n', X_train_list_normalized)
print('[test data D\'(X)]:\n', X_test_list_normalized)
# One-hot encode the integer labels for categorical cross-entropy
Y_train_one_hot = np_utils.to_categorical(Y_train_labels)
Y_test_one_hot = np_utils.to_categorical(Y_test_labels)
# Model: input -> Dense(relu) -> Dense(relu) -> Dropout(0.5) -> Dense(10, softmax).
# ``kernel_initializer`` is the Keras 2 name of the Keras 1 ``init`` argument.
model = Sequential()
# Hidden Layer I
model.add(Dense(
    units = HIDDEN_UNITS,
    input_dim = input_size,
    kernel_initializer = 'normal',
    activation = 'relu'
))
# Hidden Layer II
model.add(Dense(
    units = HIDDEN_UNITS,
    kernel_initializer = 'normal',
    activation = 'relu'
))
# Add DropOut
model.add(Dropout(0.5))
# Output Layer
model.add(Dense(
    units = 10,
    kernel_initializer = 'normal',
    activation = 'softmax'
))
# summary() prints the table itself and returns None, so don't wrap it in print()
model.summary()
# Compile model
model.compile(
    loss = 'categorical_crossentropy',
    optimizer = 'adam',
    metrics = ['accuracy']
)
# Training Start (10% of the training data is held out for validation)
train_history = model.fit(
    x = X_train_list_normalized,
    y = Y_train_one_hot,
    validation_split = 0.1,
    epochs = EPOCHS,
    batch_size = 250,
    verbose = 2
)
# Keras < 2.3 logs accuracy under 'acc'; later versions use 'accuracy'
accuracy_key = 'acc' if 'acc' in train_history.history else 'accuracy'
show_training_history(train_history, accuracy_key, 'val_' + accuracy_key, 'Training History: Accuracy')
# Print loss result
show_training_history(train_history, 'loss', 'val_loss', 'Training History: Lossy')
# Print model accuracy on the test set
scores = model.evaluate(X_test_list_normalized, Y_test_one_hot)
print('[Accuracy]: ', scores[1])
# Predict on the *normalized* test set, matching the training inputs.
# argmax over the softmax output replaces the deprecated predict_classes.
prediction = np.argmax(model.predict(X_test_list_normalized), axis = -1)
# Print prediction result
print_images_with_label(X_test_images, Y_test_labels, prediction, 0, 25)
# Black 256x256 BGR canvas that draw_sketch paints on
sketch = np.zeros((256, 256, 3), np.uint8)
cv2.namedWindow(SKETCH_NAMESPACE)
cv2.setMouseCallback(SKETCH_NAMESPACE, draw_sketch)
# Real-time handwriting recognition: Esc quits, Enter classifies and clears the canvas
while True:
    cv2.imshow(SKETCH_NAMESPACE, sketch)
    key = cv2.waitKey(1) & 0xFF
    if key == 27:  # Esc
        break
    # Enter is reported as CR (13) or LF (10) depending on platform/backend
    if key in (10, 13):
        height, width = X_train_images.shape[1], X_train_images.shape[2]
        # cv2.resize takes dsize as (width, height). INTER_AREA averages source
        # pixels, which keeps thin strokes when shrinking 256 -> training size.
        small = cv2.resize(sketch, (width, height), interpolation = cv2.INTER_AREA)
        gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
        # Same flatten + /255 preprocessing the model was trained with
        test_image = normalization(
            gray.reshape(1, height * width).astype('float32'), 255
        )
        out_prediction = np.argmax(model.predict(test_image), axis = -1)
        print(out_prediction[0])
        # Clear in place; draw_sketch keeps drawing on this same global array
        sketch[:] = 0
cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.