Updated to the Keras 2.0 API.
'''This script accompanies the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.
It uses data that can be downloaded at:
https://www.kaggle.com/c/dogs-vs-cats/data
In our setup, we:
- created a data/ folder
- created train/ and validation/ subfolders inside data/
- created cats/ and dogs/ subfolders inside train/ and validation/
- put the cat pictures index 0-999 in data/train/cats
- put the cat pictures index 1000-1400 in data/validation/cats
- put the dog pictures index 12500-13499 in data/train/dogs
- put the dog pictures index 13500-13900 in data/validation/dogs
So that we have 1000 training examples for each class, and 400 validation examples for each class.
In summary, this is our directory structure:
```
data/
    train/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
    validation/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
```
'''
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
# dimensions of our images.
img_width, img_height = 150, 150
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
nb_train_samples = 2000
nb_validation_samples = 800
epochs = 50
batch_size = 16
# Keras expects (rows, cols) = (height, width) ordering for image shapes.
if K.image_data_format() == 'channels_first':
    input_shape = (3, img_height, img_width)
    print("img data format channels_first")
else:
    input_shape = (img_height, img_width, 3)
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
# this is the configuration we will use for validation:
# only rescaling, no augmentation
test_datagen = ImageDataGenerator(rescale=1. / 255)
# target_size is (height, width)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary')
print("train_generator")
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary')
print("validation_generator")
model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)
print("fit_generator")
model.save('first_model.h5')
model.save_weights('first_weights.h5')
print("saved model and weights")
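The docstring describes the `data/` layout but not how to build it. Below is a minimal standard-library sketch that produces an equivalent split of 1000 training and 400 validation images per class. It assumes the unzipped Kaggle training images sit in one folder with names like `cat.0.jpg` and `dog.0.jpg`; the function `make_split` and its defaults are hypothetical, not part of the original gist. Because it takes the lowest-numbered images of each class, the exact files may differ from the index ranges listed in the docstring.

```python
import os
import shutil


def make_split(src_dir, dst_root, label, n_train=1000, n_val=400):
    """Hypothetical helper: copy n_train images of one class into
    dst_root/train/<label>s/ and the next n_val into dst_root/validation/<label>s/.

    Assumes src_dir holds Kaggle-style files named '<label>.<number>.jpg'.
    """
    prefix = label + "."
    # Keep only this class's files and sort them numerically by <number>,
    # so the split is reproducible (cat.2.jpg comes before cat.10.jpg).
    files = sorted(
        (f for f in os.listdir(src_dir) if f.startswith(prefix) and f.endswith(".jpg")),
        key=lambda f: int(f[len(prefix):-len(".jpg")]),
    )
    if len(files) < n_train + n_val:
        raise ValueError(f"need {n_train + n_val} '{label}' images, found {len(files)}")

    splits = {"train": files[:n_train], "validation": files[n_train:n_train + n_val]}
    for split, names in splits.items():
        out_dir = os.path.join(dst_root, split, label + "s")  # e.g. data/train/cats
        os.makedirs(out_dir, exist_ok=True)
        for name in names:
            shutil.copyfile(os.path.join(src_dir, name), os.path.join(out_dir, name))
    # Report how many files went into each split.
    return {split: len(names) for split, names in splits.items()}
```

Usage, assuming the Kaggle images were unzipped into a folder named `train`: call `make_split('train', 'data', 'cat')` and then `make_split('train', 'data', 'dog')`.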
ardianumam commented Jun 8, 2017

Thanks @ritazh. This code and the prediction code work perfectly.
I already know deep learning from auditing Fei-Fei Li's CNN course at Stanford, but I'm a total newcomer to Keras. This code is really helpful :)
By the way, I also ran the code from this GitHub and it successfully generated model_file.h5 (without weight_file.h5), but when I use that model_file I get the error "No model found in config file". When I switch to your model_file.h5 generated by this code, it works perfectly. Do you know why?

Thank you in advance :)

kumarvis commented Jul 3, 2017

What is test_datagen used for, and when are we using it? I have also gone through the prediction part (https://gist.github.com/ritazh/a7c88875053c1106e407300fc4f1d8d6), but I can't find it there either.
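In the script above, `test_datagen` is used once: it builds `validation_generator`, so validation images are only rescaled, while training images are also randomly sheared, zoomed, and flipped. The NumPy sketch below roughly mimics the pixel-level effect of the two configurations. The function names are hypothetical, shear and zoom are omitted for brevity, and it assumes a channels_last batch of shape (N, height, width, 3).

```python
import numpy as np


def rescale_only(batch):
    """Rough stand-in for ImageDataGenerator(rescale=1. / 255):
    multiply every pixel by 1/255, mapping 0..255 to 0.0..1.0."""
    return batch.astype("float32") * (1.0 / 255)


def rescale_and_flip(batch, rng):
    """Rough stand-in for the training generator: rescale, then mirror each
    image left-right with probability 0.5 (shear and zoom are left out)."""
    out = rescale_only(batch)
    flip = rng.random(len(out)) < 0.5        # one coin toss per image
    out[flip] = out[flip][:, :, ::-1, :]      # reverse the width axis
    return out
```

Applying only `rescale_only` to validation data keeps the validation score a measure of how the model does on unmodified images, while still matching the 0 to 1 pixel range the network was trained on.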

GitHubKay commented Jul 25, 2017

Hello,
how would I proceed if I had only one type of image that I want to classify? If an input image reaches 80% similarity with the type of image I'm looking for, it should be classified as one of those images; otherwise it shouldn't.

For example, I have 500 images I really like and 10000 images I know nothing about. How would I solve this problem with a CNN?
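One common way to frame this (not something the gist itself covers) is as ordinary binary classification: train the script's network with the 500 liked images as one class and a sample of the other images as the second class, then treat the final sigmoid output (for example from `model.predict`) as a score and accept an image only if it is at least 0.8. The score is the classifier's estimated probability, not a "similarity", and unlabeled images used as negatives may include some you would like. The sketch below shows only the thresholding step; the helper name and the 0.8 default are illustrative assumptions.

```python
import numpy as np


def split_by_threshold(names, scores, threshold=0.8):
    """Hypothetical helper: given image names and the sigmoid scores a binary
    classifier assigned to them, return (accepted, rejected) name lists.
    An image is accepted when its score is >= threshold."""
    scores = np.asarray(scores, dtype=float).ravel()  # (N, 1) predictions -> (N,)
    if scores.shape != (len(names),):
        raise ValueError("need exactly one score per image name")
    keep = scores >= threshold
    accepted = [n for n, k in zip(names, keep) if k]
    rejected = [n for n, k in zip(names, keep) if not k]
    return accepted, rejected
```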
