Skip to content

Instantly share code, notes, and snippets.

@Kwentar
Last active May 23, 2018 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Kwentar/47c853a0db9f98805110a04b6aaa4460 to your computer and use it in GitHub Desktop.
Save Kwentar/47c853a0db9f98805110a04b6aaa4460 to your computer and use it in GitHub Desktop.
Cats vs dogs script
from keras import Model
from keras import Sequential
from keras.applications import InceptionV3
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
import numpy as np
def get_model(count_classes):
inception_model = InceptionV3(pooling='max') # предобученная модель
for layer in inception_model.layers:
layer.trainable = False # не обучаем предпобученную модель
inception_out = inception_model.output
our_output = Dense(count_classes, activation='softmax')(inception_out)
result_model = Model(inputs=inception_model.input, outputs=our_output) # объявление модели
result_model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
return result_model
def train():
train_data_dir = 'data/train/'
validation_data_dir = 'data/valid'
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2, # сдвиг на +-20%
zoom_range=0.2, # масштабирование на +-20%
horizontal_flip=True) # отражение по горизонтали
test_datagen = ImageDataGenerator(rescale=1. / 255)
batch_size = 64
train_generator = train_datagen.flow_from_directory(
train_data_dir, # директория с тренировочными данными
target_size=(299, 299), # целевой размер картинок
batch_size=batch_size, # картинок за итерацию
class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=(299, 299),
batch_size=batch_size,
class_mode='categorical')
model = get_model(2)
model.fit_generator( # обучение модели
train_generator,
steps_per_epoch=20000//batch_size,
epochs=20,
validation_data=validation_generator,
validation_steps=5000//batch_size)
model.save_weights('my_model_weights.h5') # сохранение модели
def inference(file_name):
model = get_model(2)
model.load_weights('my_model_weights.h5')
img = np.array(image.load_img(file_name, target_size=(299, 299)))/255. # чтение изображения из файла
img = np.expand_dims(img, axis=0)
result = model.predict(img) # предсказание
return result[0]
train()
print(inference('data/train/cats/cat.1.jpg'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment