Skip to content

Instantly share code, notes, and snippets.

@amankharwal
Created November 8, 2020 06:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amankharwal/56d5cf3aae686d405921c3dd4c9a0ca2 to your computer and use it in GitHub Desktop.
Save amankharwal/56d5cf3aae686d405921c3dd4c9a0ca2 to your computer and use it in GitHub Desktop.
### Test on the validation set
# Runs the model over `validate` in fixed-size batches, prints every
# prediction next to its ground truth, and sorts the predictions into
# `good_preds` / `bad_preds`. Each entry is
# [absolute sample index, predicted class, confidence of that class].
batch_size = 16
errors = 0
good_preds = []
bad_preds = []
n_samples = len(validate)
for it in range(int(np.ceil(n_samples / batch_size))):
    # NOTE(review): despite the names, these hold a *validation* batch; the
    # names are kept so later code reading these globals keeps working.
    # Presumably get_batch(data, start, size) returns (inputs, labels) for
    # the slice starting at `start` -- confirm against its definition.
    X_train, y_train = get_batch(validate, it * batch_size, batch_size)
    result = model.predict(X_train)
    # Index of the highest-scoring class for every row of the batch.
    cla = np.argmax(result, axis=1)
    for idx, res in enumerate(result):
        predicted = cla[idx]
        confidence = res[predicted]
        print("Class:", predicted, "- Confidence:", np.round(confidence, 2), "- GT:", y_train[idx])
        # Convert the in-batch position to a position within `validate`.
        entry = [batch_size * it + idx, predicted, confidence]
        # Assumes y_train holds class indices comparable to argmax output
        # (not one-hot vectors) -- TODO confirm.
        if predicted != y_train[idx]:
            errors = errors + 1
            bad_preds.append(entry)
        else:
            good_preds.append(entry)
# Summary is reported once, after all batches. An empty validation set has
# no defined accuracy, so report NaN instead of dividing by zero.
if n_samples:
    accuracy = np.round(100 * (n_samples - errors) / n_samples, 2)
else:
    accuracy = float("nan")
print("Errors: ", errors, "Acc:", accuracy)
#Good predictions
# Shows up to five correctly classified validation images, most confident
# first. Rows of `good_preds` are [sample index, predicted class, confidence].
good_preds = np.array(sorted(good_preds, key=lambda x: x[2], reverse=True))
# Built once: it is only needed for counting how many samples a class has.
landmark_ids = list(df.landmark_id)
fig = plt.figure(figsize=(16, 16))
# Subplot positions are 1-based while rows are 0-based; enumerating the
# slice keeps the two apart so the top-ranked row is not skipped, and the
# slice simply yields fewer rows when fewer than five predictions exist.
for i, row in enumerate(good_preds[:5], start=1):
    n = int(row[0])
    img, lbl = get_image_from_number(n, validate)
    # Swap the channel order BGR -> RGB before handing the image to matplotlib.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ax = fig.add_subplot(1, 5, i)
    ax.imshow(img)
    # decode_label is fed a (1, 1) array holding the predicted class;
    # presumably it maps the encoded class back to a landmark id -- confirm.
    lbl2 = np.array(int(row[1])).reshape(1, 1)
    sample_cnt = landmark_ids.count(lbl)
    ax.set_title(
        "Label: " + str(lbl)
        + "\nClassified as: " + str(decode_label(lbl2))
        + "\nSamples in class " + str(lbl) + ": " + str(sample_cnt)
    )
    ax.axis('off')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment