# Created August 11, 2022 15:21
# Gist: nb-programmer/6349c18e85717ce3862c12e5ff53845b
# Simple script that shows images from a folder; drag a selection box to send the
# selected region to an image classifier model.
# NOTE: GitHub warned that this file may contain bidirectional Unicode text that can be
# interpreted or compiled differently than it appears. Review it in an editor that
# reveals hidden Unicode characters.
'''
Image classification (primarily for classifying cars) using the OpenCV DNN module.
Left-click and drag to select a region and send it to the classifier.
Right-click to go to the next image in the folder.
Press 'q' in the Output window to quit.
'''
import os
import sys

import cv2
import numpy as np
# Folder whose images are cycled through in the 'Output' window.
IMAGE_PATH = r'./images/'

# ONNX file loaded into the OpenCV DNN classifier at start-up.
MODEL_PATH = r"./car_classifier.onnx"

# Height in pixels every displayed image is scaled to (width follows the aspect ratio).
IMAGE_HEIGHT: int = 600
def image_getter(dir):
    """Cycle over the images in *dir*, resized to a fixed height.

    The folder is re-listed on every pass, so the generator keeps going
    round it until a complete pass produces no image at all (the folder is
    empty, or ``cv2.imread`` returns ``None`` for every entry), at which
    point it stops.

    Args:
        dir: Path of the folder to scan.

    Yields:
        Images as returned by ``cv2.imread``, resized to ``IMAGE_HEIGHT``
        pixels tall with the original width/height ratio preserved.
    """
    while True:
        loaded_any = False
        for name in os.listdir(dir):
            frame = cv2.imread(os.path.join(dir, name))
            if frame is None:
                # Not decodable as an image; skip it.
                continue
            loaded_any = True
            height, width = frame.shape[:2]
            new_width = int(IMAGE_HEIGHT * (width / height))
            yield cv2.resize(frame, (new_width, IMAGE_HEIGHT))
        if not loaded_any:
            break
# Image currently displayed in the 'Output' window; (re)assigned by next_img().
curr_img = None
# Generator of resized images from IMAGE_PATH, advanced by next_img().
img_iter = image_getter(IMAGE_PATH)
def next_img():
    """Advance the global ``curr_img`` to the next image from ``img_iter``.

    When the image generator is exhausted (no readable image is left in the
    folder), print a message and terminate the script with exit status 1.

    Raises:
        SystemExit: with code 1 when there are no more images.
    """
    global curr_img
    try:
        curr_img = next(img_iter)
    except StopIteration:
        print("No more images")
        # Use sys.exit rather than the bare exit() builtin: exit() is injected
        # by the site module for interactive use and is absent under
        # `python -S` or in frozen executables.
        sys.exit(1)
# Mouse state shared between mouse_click (writer) and the main loop (reader):
#   is_dragging - left button held, selection in progress
#   clicked     - a selection just finished and is waiting to be classified
#   dx, dy      - corner where the current/last drag started
#   x, y        - most recent cursor position
drag = dict(is_dragging=False, clicked=False, dx=0, dy=0, x=0, y=0)
def mouse_click(event, x, y,
                flags, param):
    """Mouse callback registered on the 'Output' window.

    * Left button down: start a selection at the cursor.
    * Left button up: finish the selection and flag it for classification.
    * Right button down: advance to the next image.

    Every event also records the cursor position, which the main loop uses
    to draw the live selection rectangle.

    Args:
        event: ``cv2.EVENT_*`` code of this mouse event.
        x, y: Cursor position in the window.
        flags, param: Required by the OpenCV callback signature; unused.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        drag.update({'is_dragging': True, 'dx': x, 'dy': y})
    elif event == cv2.EVENT_LBUTTONUP:
        # Only a release that ends a drag started in this window is a
        # selection. A stray release can arrive without a matching press
        # (e.g. the button was pressed outside the window), and treating it
        # as a selection would classify a stale region.
        if drag['is_dragging']:
            drag.update({'is_dragging': False, 'clicked': True})
    elif event == cv2.EVENT_RBUTTONDOWN:
        next_img()
    drag.update({'x': x, 'y': y})
# The classifier network, read once at start-up from the ONNX file.
car_classifier = cv2.dnn.readNetFromONNX(MODEL_PATH)

# Display label for each index along the network's output axis 1.
car_classes = {
    0: 'No vehicle',
    1: 'Vehicle',
}
def classifyImg(img):
    """Run ``car_classifier`` on *img* and show the labelled result.

    The region is resized to 256x256 for display and converted to a 64x64
    DNN input blob. The arg-max class and its score are drawn onto the
    resized image, which is then shown in the 'Classifier' window.

    Args:
        img: Image region to classify (a NumPy array, e.g. a slice of the
            displayed frame). ``None`` or an empty (zero-area) array is
            ignored.
    """
    # A zero-area selection (a plain click, or a box with no width/height)
    # yields an empty array, which cv2.resize rejects with an error.
    if img is None or img.size == 0:
        return
    img = cv2.resize(img, (256, 256))
    # NOTE(review): scalefactor=0.3 is unusual (1/255 is the common pixel
    # normalisation); presumably it matches the model's training
    # preprocessing, so it is kept unchanged. Confirm.
    nnIn = cv2.dnn.blobFromImage(img, scalefactor=0.3, size=(64, 64))
    car_classifier.setInput(nnIn)
    res = car_classifier.forward()
    winner = np.argmax(res, axis=1)[0]
    win_acc = res[0][winner]
    # Score is shown as a percentage. This presumes the network outputs
    # probabilities (e.g. a softmax head); TODO confirm.
    cv2.putText(img, "Class: {}({:.2f}%)".format(car_classes[winner], win_acc * 100),
                (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (192, 48, 16), 2)
    cv2.imshow('Classifier', img)
# Main loop: show the current image, draw the selection being dragged, and
# classify a finished selection. Quit with 'q'.
cv2.namedWindow('Output')
cv2.setMouseCallback('Output', mouse_click)
next_img()
while True:
    img = curr_img.copy()
    if drag['is_dragging']:
        # Live preview of the selection rectangle.
        cv2.rectangle(img, (drag['dx'], drag['dy']), (drag['x'], drag['y']), (255, 0, 0), 2)
    if drag['clicked']:
        drag['clicked'] = False
        # Order the corners (the drag may go in any direction), then clamp
        # them to the image. The cursor can leave the window mid-drag, and a
        # negative index would silently wrap around in a NumPy slice.
        height, width = curr_img.shape[:2]
        x0, x1 = sorted((drag['dx'], drag['x']))
        y0, y1 = sorted((drag['dy'], drag['y']))
        x0, x1 = max(0, min(x0, width)), max(0, min(x1, width))
        y0, y1 = max(0, min(y0, height)), max(0, min(y1, height))
        # Slice first, so only the selected region is copied.
        crp = curr_img[y0:y1, x0:x1].copy()
        # A zero-area selection would make the classifier's resize fail.
        if crp.size:
            classifyImg(crp)
    cv2.imshow("Output", img)
    # Keep only the low byte: some backends set extra modifier bits in the code.
    k = cv2.waitKey(10) & 0xFF
    if k == ord('q'):
        break
cv2.destroyAllWindows()
# (End of the GitHub gist page: sign up or sign in to GitHub to comment on this gist.)