Skip to content

Instantly share code, notes, and snippets.

@NickShargan
Last active October 21, 2017 22:12
Show Gist options
  • Save NickShargan/6708733982187c03b2da0a07f19f8854 to your computer and use it in GitHub Desktop.
Save NickShargan/6708733982187c03b2da0a07f19f8854 to your computer and use it in GitHub Desktop.
import os
import cv2
from tqdm import tqdm
import numpy as np
import pandas as pd
from keras.applications.mobilenet import MobileNet
from keras.models import Model
from keras.layers import Activation, GlobalAveragePooling2D, Dense
# Inference script: rebuild the fine-tuned MobileNet classifier, run it on
# every test image listed in the sample submission, and write the predicted
# per-breed probabilities to a new submission CSV.
test_dir = "./DogBreed/test/"
sample_submission_path = "./sample_submission.csv"

# MobileNet backbone without its classification head, followed by a custom
# head: global average pooling -> 1024-unit ReLU -> 120-way softmax.
base_model = MobileNet(input_shape=(224, 224, 3), include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(120, activation='softmax')(x)
# Keras 2 names these keyword arguments `inputs`/`outputs`; the singular
# Keras 1 spellings are deprecated.
model = Model(inputs=base_model.input, outputs=predictions)
model.load_weights("./models/1_w_01_0.68.hdf5")

df_test = pd.read_csv(sample_submission_path)
# Every column after the first one ("id") is a breed probability column.
dog_breeds = list(df_test.columns.values)[1:]
# Keep one entry per CSV row, in row order, so that no row is skipped even if
# an id happens to be repeated.
filenames = list(df_test["id"])

for idx, filename in tqdm(zip(df_test.index, filenames), total=len(filenames)):
    file_path = test_dir + filename + ".jpg"
    if not os.path.isfile(file_path):
        # Rows whose image is missing keep the values from the sample file.
        print("DEBUG: file doest exist: %s" % file_path)
        continue

    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals an unreadable/corrupt file by returning None
        # instead of raising; skip it rather than crash inside cv2.resize.
        print("DEBUG: could not read image: %s" % file_path)
        continue

    # NOTE(review): cv2.imread yields BGR channel order and the pixels are
    # scaled to [0, 1] here - confirm this matches the training pipeline.
    img_in = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img_in = img_in / 255.
    # Add the leading batch dimension expected by model.predict.
    img_in = np.reshape(img_in, (1, 224, 224, 3))
    y = model.predict(img_in)

    # Select the breed columns by label: `.loc` is label-based, so a
    # positional slice such as `1:` is not valid on string column names.
    df_test.loc[idx, dog_breeds] = y[0]

print(df_test)
df_test.to_csv("subm_3.csv", index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment