Skip to content

Instantly share code, notes, and snippets.

@nb-programmer
Created August 11, 2022 15:21
Show Gist options
  • Save nb-programmer/6349c18e85717ce3862c12e5ff53845b to your computer and use it in GitHub Desktop.
Save nb-programmer/6349c18e85717ce3862c12e5ff53845b to your computer and use it in GitHub Desktop.
Simple script that shows images from a folder; drag a selection box to send that region to an image classifier model.
'''
Image classification (primarily for classifying cars) using OpenCV's DNN module.
Left-click and drag to select a region and send it to the classifier.
Right-click to go to the next image in the folder.
Press 'q' to quit.
'''
import cv2
import numpy as np
import os
# Every displayed image is rescaled to this height in pixels (aspect ratio kept).
IMAGE_HEIGHT = 600
# ONNX classifier weights, relative to the working directory.
MODEL_PATH = r'./car_classifier.onnx'
# Folder whose images are browsed by image_getter(), relative to the working directory.
IMAGE_PATH = r"./images/"
def image_getter(dir, height=None):
    """Yield images from *dir* endlessly, each resized to a fixed height.

    The folder is re-scanned on every pass, so files added or removed while
    the script runs are picked up on the next cycle. Entries that
    ``cv2.imread`` cannot decode (it returns ``None``) are skipped.

    Args:
        dir: Directory to read images from.
        height: Target height in pixels; defaults to ``IMAGE_HEIGHT``.
            The width is scaled to preserve the aspect ratio.

    Yields:
        Resized images as returned by ``cv2.imread``.

    The generator finishes only when a complete pass loads no image at all,
    which includes the case of an empty directory.
    """
    if height is None:
        height = IMAGE_HEIGHT
    while True:
        loaded_any = False
        # os.listdir returns entries in arbitrary order; sort for a stable,
        # predictable browsing order.
        for name in sorted(os.listdir(dir)):
            img = cv2.imread(os.path.join(dir, name))
            if img is None:
                continue
            loaded_any = True
            aspect = img.shape[1] / img.shape[0]
            # A very tall, narrow image could round to width 0, which
            # cv2.resize rejects.
            width = max(1, int(height * aspect))
            yield cv2.resize(img, (width, height))
        if not loaded_any:
            return
# Endless generator over the folder's images; advanced by next_img().
img_iter = image_getter(IMAGE_PATH)
# Image currently on display; stays None until next_img() first runs.
curr_img = None
def next_img():
    """Advance to the next image and store it in the global ``curr_img``.

    Pulls the next frame from the module-level ``img_iter`` generator. When
    the generator is exhausted (no loadable images remain), it prints a
    message and terminates the process with exit status 1.

    Raises:
        SystemExit: with code 1 when there are no more images.
    """
    global curr_img
    try:
        curr_img = next(img_iter)
    except StopIteration:
        print("No more images")
        # The builtin ``exit`` is an interactive-shell helper installed by the
        # ``site`` module and is missing under ``python -S``. Raising
        # SystemExit directly is what sys.exit() does and needs no import.
        raise SystemExit(1) from None
# Shared mouse state: written by mouse_click(), read by the main loop.
#   is_dragging - left button currently held
#   clicked     - a selection was just released and awaits classification
#   dx, dy      - position where the left button was pressed
#   x, y        - most recent cursor position
drag = dict(is_dragging=False, clicked=False, dx=0, dy=0, x=0, y=0)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback for the 'Output' window.

    Updates the shared ``drag`` state:

    * left press   - start a selection; the press point is stored in dx/dy.
    * left release - finish the selection and flag it (``clicked``) so the
      main loop classifies it.
    * right press  - advance to the next image via ``next_img()``.

    Every event also records the current cursor position in ``x``/``y``.
    ``flags`` and ``param`` are required by the OpenCV callback signature
    but are not used here.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        drag.update({'is_dragging': True, 'dx': x, 'dy': y})
    elif event == cv2.EVENT_LBUTTONUP:
        # Only a release that follows a registered press completes a
        # selection. Otherwise dx/dy hold a stale start point. One example is
        # a double-click, which OpenCV may report as EVENT_LBUTTONDBLCLK
        # instead of a second EVENT_LBUTTONDOWN.
        if drag['is_dragging']:
            drag.update({'is_dragging': False, 'clicked': True})
    elif event == cv2.EVENT_RBUTTONDOWN:
        next_img()
    drag.update({'x': x, 'y': y})
# ONNX network used by classifyImg(); loaded once, when the script starts.
car_classifier = cv2.dnn.readNetFromONNX(MODEL_PATH)
# Display names indexed by the network's output position.
car_classes = dict(enumerate(['No vehicle', 'Vehicle']))
def classifyImg(img):
    """Run ``car_classifier`` on *img* and show the result.

    The image is resized to 256x256 for display and converted into a 64x64
    input blob, which is passed through the network. The winning class and
    its score are drawn onto the 256x256 image, which is shown in the
    'Classifier' window.

    Args:
        img: Image region to classify (a NumPy array from OpenCV). ``None``
            or an empty array is ignored, and nothing is shown.
    """
    # An empty crop (e.g. a plain click with no drag) makes cv2.resize raise.
    # Skip it rather than crash the UI loop.
    if img is None or img.size == 0:
        return
    img = cv2.resize(img, (256, 256))
    # NOTE(review): scalefactor=0.3 presumably mirrors the model's training
    # preprocessing - confirm before changing.
    nnIn = cv2.dnn.blobFromImage(img, scalefactor=0.3, size=(64, 64))
    car_classifier.setInput(nnIn)
    res = car_classifier.forward()
    winner = int(np.argmax(res, axis=1)[0])
    # Assumes the outputs are per-class probabilities (e.g. softmax) - TODO
    # confirm. The score is displayed as a percentage.
    win_acc = res[0][winner]
    # Fall back to the raw index if the model emits more classes than are named.
    label = car_classes.get(winner, str(winner))
    cv2.putText(img, "Class: {}({:.2f}%)".format(label, win_acc * 100), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (192, 48, 16), 2)
    cv2.imshow('Classifier', img)
# Interactive loop: draw the live selection, classify released selections,
# and quit on 'q'.
cv2.namedWindow('Output')
cv2.setMouseCallback('Output', mouse_click)
next_img()
while True:
    img = curr_img.copy()
    if drag['is_dragging']:
        # Live preview of the rectangle being dragged.
        cv2.rectangle(img, (drag['dx'], drag['dy']), (drag['x'], drag['y']), (255, 0, 0), 2)
    if drag['clicked']:
        drag['clicked'] = False
        h, w = curr_img.shape[:2]
        # Normalise the two drag corners and clamp them to the image. The
        # cursor can report positions outside the image mid-drag, and a
        # negative index would make the NumPy slice wrap to the wrong region.
        x0 = max(0, min(drag['dx'], drag['x']))
        y0 = max(0, min(drag['dy'], drag['y']))
        x1 = min(w, max(drag['dx'], drag['x']))
        y1 = min(h, max(drag['dy'], drag['y']))
        if x1 > x0 and y1 > y0:
            # Slice first, then copy only the crop (not the whole frame).
            classifyImg(curr_img[y0:y1, x0:x1].copy())
        else:
            print("Selection is empty; drag to select a region")
    cv2.imshow("Output", img)
    # Keep only the low byte: some OpenCV builds/backends may return extra
    # high-order bits, which would break the 'q' comparison.
    k = cv2.waitKey(10) & 0xFF
    if k == ord('q'):
        break
cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment