Last active
June 25, 2017 09:59
-
-
Save kikuchy/9595db231ab5b323fb945b905ad43ac4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.applications.inception_v3 import InceptionV3 | |
from keras.applications.inception_v3 import preprocess_input | |
from keras.models import Sequential, Model | |
from keras.layers import Dense, Dropout, Activation, Flatten | |
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, AveragePooling2D | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau | |
from keras.optimizers import SGD | |
from keras.regularizers import l2 | |
import matplotlib.image as mpimg | |
from scipy.misc import imresize | |
import numpy as np | |
import keras.backend as K | |
import math | |
K.clear_session()

img_size = 299
train_data_dir = "data/train"
validation_data_dir = "data/validation"
train_batch_size = 64
validation_batch_size = 32

# Training-data augmentation.
# NOTE(review): `preprocess_input` is imported from
# keras.applications.inception_v3 but never applied, so images reach
# InceptionV3 as raw 0-255 pixel values. The ImageNet weights presumably
# expect inputs scaled by preprocess_input -- confirm. If you enable it
# (preprocessing_function=preprocess_input on BOTH generators), the
# prediction script must apply the same preprocessing before model.predict.
train_datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=False,
    zoom_range=[.8, 1],
    channel_shift_range=30,
    fill_mode='reflect')
# Validation images are used as-is (no augmentation).
test_datagen = ImageDataGenerator()

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_size, img_size),
    batch_size=train_batch_size,
    class_mode="binary")
test_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_size, img_size),
    batch_size=validation_batch_size,
    class_mode="binary")

# Load InceptionV3 with ImageNet weights, without its top classification layers.
base_model = InceptionV3(weights='imagenet', include_top=False)

# New head: global average pooling followed by one sigmoid unit (binary output).
x = base_model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(1, kernel_initializer="glorot_uniform", activation="sigmoid",
                    kernel_regularizer=l2(.0005))(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze the pretrained base so only the new head's weights are updated.
for layer in base_model.layers:
    layer.trainable = False

opt = SGD(lr=.01, momentum=.9)
model.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])

# Save a checkpoint only when val_loss improves; the filename embeds epoch and val_loss.
checkpointer = ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.hdf5',
                               verbose=1, save_best_only=True)
csv_logger = CSVLogger('model.log')
# Multiply the learning rate by 0.2 after 5 epochs without val_loss
# improvement, never going below 0.001.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=0.001)

# One epoch = one full pass over each directory. The previous hard-coded
# values (2000 / 800) presumably were sample counts from the Keras 1
# samples_per_epoch / nb_val_samples API; used as step counts they ran
# 2000*64 training and 800*32 validation images per epoch regardless of
# dataset size. float() keeps this true division on Python 2 as well.
steps_per_epoch = int(math.ceil(train_generator.samples / float(train_batch_size)))
validation_steps = int(math.ceil(test_generator.samples / float(validation_batch_size)))

history = model.fit_generator(train_generator,
                              steps_per_epoch=steps_per_epoch,
                              epochs=10,
                              validation_data=test_generator,
                              validation_steps=validation_steps,
                              verbose=1,
                              callbacks=[reduce_lr, csv_logger, checkpointer])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import load_model | |
from keras.preprocessing import image | |
import numpy as np | |
# Side length (pixels) every input image is resized to before prediction.
img_size = 299
# Trained checkpoint to load; the filename presumably encodes epoch and
# val_loss from the training run -- update it when retraining.
checkpoint_path = './model.00-0.06.hdf5'
model = load_model(checkpoint_path)
def predict_img(img_name, threshold=0.5):
    """Classify one image file as 'cat' or 'dog', print and return the result.

    The image is loaded, resized to (img_size, img_size) and fed to the
    module-level `model`. No pixel preprocessing is applied; this presumably
    has to match how the model was trained -- confirm when retraining.

    Args:
        img_name: path of the image file to classify.
        threshold: model outputs below this value are labelled 'cat',
            all others 'dog'. Defaults to the original cut-off of 0.5.

    Returns:
        A (probability, label) tuple: the model's single output value for the
        image and the string 'cat' or 'dog'.
    """
    img_arr = image.img_to_array(image.load_img(img_name, target_size=(img_size, img_size)))
    # Add a leading batch axis: the model predicts on a batch of one image.
    batch = np.expand_dims(img_arr, axis=0)
    y_pred = model.predict(batch)
    # One image, one output unit -> take the scalar instead of comparing the
    # whole (1, 1) array against the threshold.
    prob = y_pred[0][0]
    # NOTE(review): assumes class index 0 is 'cat' and 1 is 'dog' in the
    # training directories (flow_from_directory orders them alphabetically) -- confirm.
    label = 'cat' if prob < threshold else 'dog'
    print (prob, label)
    return prob, label
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment