Created June 12, 2016 23:04
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
import scipy.misc
import scipy.ndimage
import numpy as np
import argparse
parser = argparse.ArgumentParser(description='Decide if an image is a picture of a bird')
parser.add_argument('image', type=str, help='The image file to check')
args = parser.parse_args()
# Same network definition as before
img_prep = ImagePreprocessing()
img_aug = ImageAugmentation()
network = input_data(shape=[None, 32, 32, 3],
                     data_preprocessing=img_prep,
                     data_augmentation=img_aug)
network = conv_2d(network, 32, 3, activation='relu')
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2)
network = fully_connected(network, 512, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 2, activation='softmax')
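# categorical_crossentropy is tflearn's default loss for regression()
network = regression(network, optimizer='adam',
                     loss='categorical_crossentropy')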
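model = tflearn.DNN(network, tensorboard_verbose=0, checkpoint_path='bird-classifier.tfl.ckpt')
# Load trained weights before predicting. Assumption: the file name below is a
# placeholder; substitute the checkpoint written by your own training run.
model.load('bird-classifier.tfl.ckpt')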
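# Load the image file as RGB (scipy.ndimage.imread needs Pillow and was removed in SciPy 1.2)
img = scipy.ndimage.imread(args.image, mode="RGB")
# Scale it to 32x32 (scipy.misc.imresize was removed in SciPy 1.3)
img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')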
# Predict
prediction = model.predict([img])
# Check the result.
is_bird = np.argmax(prediction[0]) == 1
if is_bird:
    print("That's a bird!")
else:
    print("That's not a bird!")

Does this script need "bird-classifier.tfl.ckpt" to work?

ayo fam how are we supposed to identify birds without that sweet sweet data set? gimme gimme please

off99555 commented Jul 2, 2016

You should also report the confidence of the prediction: at around line 56, print np.max(prediction[0]).
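
The suggestion above could be sketched roughly as follows. This is a minimal illustration, not part of the original script: `describe_prediction` is a hypothetical helper, and the hard-coded `prediction` list is an assumed stand-in for the two-element softmax row returned by `model.predict([img])`, with index 1 meaning "bird" as in the script's `np.argmax(prediction[0]) == 1` check.

```python
import numpy as np


def describe_prediction(probabilities, bird_index=1):
    """Return (is_bird, confidence) for one row of softmax output.

    Hypothetical helper: `bird_index=1` mirrors the script's
    `np.argmax(prediction[0]) == 1` check.
    """
    probabilities = np.asarray(probabilities, dtype=np.float64)
    predicted_class = int(np.argmax(probabilities))
    # The confidence is the probability assigned to the winning class,
    # which is the same value as np.max(probabilities).
    confidence = float(probabilities[predicted_class])
    return predicted_class == bird_index, confidence


# Assumed example values standing in for model.predict([img])
prediction = [[0.08, 0.92]]

is_bird, confidence = describe_prediction(prediction[0])
label = "a bird" if is_bird else "not a bird"
print("That's {} ({:.1%} confident)".format(label, confidence))
```

A softmax score is only the network's own estimate, so a high value does not guarantee a correct answer.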

simicvm commented Jul 29, 2016

If somebody is still wondering, this script is a part of Adam's nice tutorial on machine learning, over on Medium.
You'll find a reference for this bird dataset in part 3.

avanish commented Oct 25, 2017

Hi, can you tell me what your Y dataset looks like? The way I did it was [[1., 0.], ... , [0., 1.]]. I'm always getting [0., 1.] as my prediction.
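
For reference, the label layout described in this question can be built and sanity-checked with NumPy alone. The sketch below is an illustration under assumptions: `to_one_hot` is a hypothetical helper, the class ids are made-up sample values, and column 1 is taken to mean "bird" because the script tests `np.argmax(prediction[0]) == 1`. Counting examples per class is one thing worth checking when a model always predicts the same label, though it may not be the cause here.

```python
import numpy as np


def to_one_hot(class_ids, num_classes=2):
    """Convert integer class ids into one-hot rows (hypothetical helper)."""
    class_ids = np.asarray(class_ids, dtype=np.int64)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[class_ids]


# Made-up sample labels: 0 = not a bird, 1 = bird (assumed convention)
class_ids = [0, 1, 1, 0, 1]
Y = to_one_hot(class_ids)

# Summing the columns gives the number of examples in each class.
counts = Y.sum(axis=0)
print(Y)
print("examples per class:", counts)
```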

How do I run this? Can you explain?

In parser.add_argument('image', ...), does 'image' refer to the path to the image file?
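
A small sketch of how the positional argument behaves: `argparse` stores whatever string follows the script name on the command line under `args.image`, and the script then passes that string to the image loader as a file path. The path `photos/robin.jpg` is a made-up example, and passing a list to `parse_args` is used here only to simulate a command line.

```python
import argparse

parser = argparse.ArgumentParser(description='Decide if an image is a picture of a bird')
parser.add_argument('image', type=str, help='The image file to check')

# Passing a list simulates running the script with one argument,
# for example: python <script name> photos/robin.jpg
args = parser.parse_args(['photos/robin.jpg'])

# The positional argument is available as a plain string.
print(args.image)
```

Because the argument is positional, running the script without it makes `argparse` print a usage message and exit.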

sahukk commented Jan 10, 2018

The supported image file types are not clear. Should the test images be JPEG, BMP, GIF, or TIFF?
