# NOTE: the lines below are GitHub gist page chrome captured with the file;
# they are not part of the program.
# Skip to content
# Instantly share code, notes, and snippets.
# @ksob
# Created February 22, 2019 06:03
# Show Gist options
#   • Save ksob/6ac646104569fbbfeb156ef8fc0e6509 to your computer and use it in GitHub Desktop.
# File: predict.py
import json
import os
import sys
from glob import iglob
from pathlib import Path
import numpy as np
import tensorflow as tf
from keras.applications.xception import preprocess_input
from keras.engine.saving import load_model
from keras_preprocessing.image import load_img, img_to_array
from keras.backend.tensorflow_backend import set_session
# --- Module-level setup. Runs at import time; statement order matters. ---

# Set the working directory to sys.path[0] (the script's directory when run
# as a script) so the relative paths used below and in the helpers
# ("to_class.json", "./models/train", "./images") resolve next to the script.
os.chdir(sys.path[0])

# Reduce TensorFlow log noise.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is set after `import tensorflow`;
# confirm it still takes effect for the native logger at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.logging.set_verbosity(tf.logging.ERROR)

# Register the configured session as the one the Keras backend uses.
sess = tf.Session(config=config)
set_session(sess)

# Lookup table from predicted class index (as a string key) to class name;
# predict() reads it as to_class[str(index)].
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default encoding.
with open("to_class.json", 'r', encoding="utf-8") as f:
    to_class = json.load(f)

# Input size the images are resized to before being fed to the model.
IMG_WIDTH, IMG_HEIGHT = 299, 299

# Trained model shared by all helpers below.
m = load_model("./models/train")
def get_files(root_dir):
    """Return a sorted list of every regular file below ``root_dir``.

    Args:
        root_dir: Directory to scan recursively, as a ``str`` or any
            ``os.PathLike`` object.

    Returns:
        A sorted ``list`` of ``pathlib.Path`` objects, one per file.
        Directories are skipped. The list is empty when ``root_dir`` does
        not exist or contains no files.
    """
    # Local import keeps this function self-contained; only ``iglob`` is
    # imported at file level.
    from glob import escape

    # Escape glob metacharacters ('[', '*', '?') in the directory name so a
    # folder such as "set[1]" is matched literally rather than as a pattern.
    pattern = escape(os.fspath(root_dir)) + '/**'
    return sorted(
        Path(filename)
        for filename in iglob(pattern, recursive=True)
        if os.path.isfile(filename)  # filter dirs
    )
def predict(file, model, to_class):
    """Classify one image file and return its class name.

    Args:
        file: Path of the image to classify.
        model: Object exposing ``predict(batch)``; the position of the
            largest value in its output is taken as the class index.
        to_class: Mapping from the class index, as a string, to a class name.

    Returns:
        The class name for the predicted index, or ``''`` when the file
        could not be loaded or classified (for example, it is not an image).

    Raises:
        KeyError: If the predicted index has no entry in ``to_class``.
    """
    try:
        # Keras expects target_size as (height, width).
        im = load_img(file, target_size=(IMG_HEIGHT, IMG_WIDTH))
        x = img_to_array(im)
        # Add a leading axis so the single image forms a batch of one.
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        index = model.predict(x).argmax()
    except Exception:
        # Best-effort: unreadable or non-image files yield an empty label.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, so a long run can still be stopped.
        return ''
    return to_class[str(index)]
def categorize_dir(dir_str):
    """Print the predicted category of every file under ./images/<dir_str>.

    One line is printed per file, in the form ``<path> -> <category>``.
    """
    target_dir = str(Path("./images") / dir_str)
    for file_name in map(str, get_files(target_dir)):
        label = predict(file_name, m, to_class)
        print(file_name, "->", label)
def look_for(root_path, real_cat, correct):
    """Print the files under ``root_path`` selected by their predicted category.

    When ``correct`` is truthy, print the path of each file predicted as
    ``real_cat``. Otherwise print ``<path> -> <category>`` for each file
    predicted as anything other than ``real_cat``.
    """
    for entry in get_files(str(root_path)):
        file_name = str(entry)
        label = predict(file_name, m, to_class)
        if correct:
            # Matching mode: only files that landed in the wanted category.
            if label == real_cat:
                print(file_name)
            continue
        # Mismatch mode: report files that landed in any other category.
        if label != real_cat:
            print(file_name, "->", label)
def categorize_percent_dir(dir_str, real_cat):
    """Print and return the percentage of files predicted as ``real_cat``.

    Args:
        dir_str: Sub-directory of ``./images`` to scan recursively.
        real_cat: Category name each prediction is compared against.

    Returns:
        The share of files under ``./images/<dir_str>`` whose predicted
        category equals ``real_cat``, as a float from 0 to 100. Returns
        ``0.0`` when no files are found.
    """
    root_path = Path("./images") / dir_str
    pathlist = get_files(str(root_path))
    if not pathlist:
        # Nothing to classify: report it instead of dividing by zero.
        print("No files found in", root_path, file=sys.stderr)
        return 0.0
    correct = sum(
        1 for path in pathlist
        if predict(str(path), m, to_class) == real_cat
    )
    percentage = (correct / len(pathlist)) * 100
    print(percentage, "%")
    return percentage
if __name__ == "__main__":
    # Usage: predict.py <absolute path to images directory>
    # The path must be absolute because the working directory is changed at
    # module load (os.chdir above), so a relative path would resolve there.
    try:
        path_to_images = Path(sys.argv[1])
    except IndexError:
        # Only a missing command-line argument is handled here; errors
        # raised while classifying are left to surface as they are.
        print("You have to write absolute path to images directory")
    else:
        # Print every file under the directory that is predicted as 'plan'.
        look_for(path_to_images, 'plan', True)
# (Gist page footer, not part of the program:) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.