Instantly share code, notes, and snippets.

Embed
What would you like to do?
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
import scipy
import numpy as np
import argparse
# Command-line interface: one required positional argument naming the image
# to classify. The value is later passed straight to the image loader, so it
# is a filesystem path.
parser = argparse.ArgumentParser(description='Decide if an image is a picture of a bird')
parser.add_argument('image', type=str, help='Path to the image file to check')
args = parser.parse_args()
# Rebuild the network graph so the saved weights can be loaded into it.
# The original note says this is the same definition "as before"
# (presumably the training script), so layer order and sizes are kept as-is.

# Input preprocessing: feature-wise zero-centering and std normalisation.
img_prep = ImagePreprocessing()
img_prep.add_featurewise_zero_center()
img_prep.add_featurewise_stdnorm()

# Augmentation pipeline: random left/right flips, random rotation
# (max_angle 25) and random blur (sigma_max 3).
img_aug = ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_rotation(max_angle=25.)
img_aug.add_random_blur(sigma_max=3.)

# Input layer: batches of 32x32 images with 3 channels.
network = input_data(
    shape=[None, 32, 32, 3],
    data_preprocessing=img_prep,
    data_augmentation=img_aug,
)

# Convolution stage 1: one 3x3 conv with 32 filters, then 2x2 max-pooling.
network = conv_2d(network, nb_filter=32, filter_size=3, activation='relu')
network = max_pool_2d(network, kernel_size=2)

# Convolution stage 2: two 3x3 convs with 64 filters each, then max-pooling.
network = conv_2d(network, nb_filter=64, filter_size=3, activation='relu')
network = conv_2d(network, nb_filter=64, filter_size=3, activation='relu')
network = max_pool_2d(network, kernel_size=2)

# Classifier head: 512-unit dense layer, dropout, then a 2-way softmax.
network = fully_connected(network, n_units=512, activation='relu')
network = dropout(network, keep_prob=0.5)
network = fully_connected(network, n_units=2, activation='softmax')

# Training op definition (Adam optimizer, categorical cross-entropy loss).
network = regression(
    network,
    optimizer='adam',
    loss='categorical_crossentropy',
    learning_rate=0.001,
)

# Wrap the graph in a DNN model and load the saved checkpoint.
model = tflearn.DNN(
    network,
    tensorboard_verbose=0,
    checkpoint_path='bird-classifier.tfl.ckpt',
)
model.load("bird-classifier.tfl.ckpt-50912")
# `import scipy` alone does not guarantee that the `ndimage` and `misc`
# subpackages are loaded, so import them explicitly before use.
import scipy.misc
import scipy.ndimage

# Load the image file as an RGB array.
# NOTE(review): `scipy.ndimage.imread` and `scipy.misc.imresize` are not
# available in newer SciPy releases; this script needs a SciPy version that
# still ships them -- confirm the pinned version before running.
img = scipy.ndimage.imread(args.image, mode="RGB")

# Scale it to the 32x32 size declared by the network's input layer and
# convert the pixel values to float32.
img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

# Predict on a batch containing just this one image.
prediction = model.predict([img])

# The network ends in a 2-way softmax; this script treats output index 1 as
# the "bird" class.
is_bird = np.argmax(prediction[0]) == 1

if is_bird:
    print("That's a bird!")
else:
    print("That's not a bird!")
@tantalor

This comment has been minimized.

Copy link

tantalor commented Jun 14, 2016

Does this script need "bird-classifier.tfl.ckpt" to work?

@petercunha

This comment has been minimized.

Copy link

petercunha commented Jun 16, 2016

ayo fam how are we supposed to identify birds without that sweet sweet data set? gimme gimme please

@off99555

This comment has been minimized.

Copy link

off99555 commented Jul 2, 2016

You should also report the confidence percentage of the prediction at around line 56, report np.max(prediction[0])

@simama

This comment has been minimized.

Copy link

simama commented Jul 29, 2016

If somebody is still wondering, this script is a part of Adam's nice tutorial on machine learning, over on Medium. https://medium.com/@ageitgey/machine-learning-is-fun-80ea3ec3c471#.7pmvr7722.
You'll find a reference for this bird dataset in part 3.

@avanish

This comment has been minimized.

Copy link

avanish commented Oct 25, 2017

Hi, can you tell me what your Y dataset looks like? The way I did it was [[1., 0.], ... , [0., 1.]]. I'm always getting [0., 1.] as my prediction.

@harishyadav1465

This comment has been minimized.

Copy link

harishyadav1465 commented Dec 12, 2017

How do I run this? Can you explain?

@adolfoyanes

This comment has been minimized.

Copy link

adolfoyanes commented Dec 12, 2017

In parser.add_argument('image', type=str, help='The image image file to check'), does 'image' refer to the path to the image file?

@sahukk

This comment has been minimized.

Copy link

sahukk commented Jan 10, 2018

The supported image file types are not clear. Should the test images be JPEG, BMP, GIF, or TIF?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment