# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @bigsnarfdude
# Created July 18, 2019 18:51
# Show Gist options
#   • Star 0 (you must be signed in to star a gist)
#   • Fork 0 (you must be signed in to fork a gist)
#   • Save bigsnarfdude/a7b902d84eb328c7d3462b04ecf1844f to your computer and use it in GitHub Desktop.
# single_plate_tensorflow.py
import argparse
import time
import cv2
import numpy as np
import tensorflow as tf
from imutils import paths
from object_detection.utils import label_map_util
from base2designs.plates.plateFinder import PlateFinder
from base2designs.plates.predicter import Predicter
from base2designs.plates.plateDisplay import PlateDisplay
# --- Configuration ----------------------------------------------------------
# Serialized frozen inference graph (a GraphDef protobuf) read further down.
model_file = 'datasets_old/experiment_faster_rcnn/2018_08_02/exported_model/frozen_inference_graph.pb'
# Label map (.pbtxt) translating class ids to display names.
labels_file = 'classes/classes.pbtxt'
# NOTE(review): despite the "_file" suffix this is an int -- it is passed as
# `max_num_classes` when converting the label map.
num_classes_file = 37
# NOTE(review): despite the "_file" suffix this is a float; it is the first
# argument to PlateFinder, presumably the minimum detection score -- confirm.
confidence_file = 0.1
# Directory scanned (via imutils.paths.list_images) for input images.
images_file = 'images'

# Runtime options, kept in a dict shaped like a parsed-arguments mapping.
args = {
    # 2: detect plates on the full image, then chars inside each plate crop;
    # 1: detect plates and chars together in a single pass.
    "pred_stages": 2,
    # Show each labelled image and block until a key is pressed.
    "image_display": True,
}
# Load the frozen inference graph into its own tf.Graph so the session
# created later can be bound to it explicitly (tf.Session(graph=model)).
model = tf.Graph()
with model.as_default():
    # Pull the serialized GraphDef bytes off disk; the file handle is closed
    # as soon as the bytes are in memory.
    with tf.gfile.GFile(model_file, "rb") as f:
        serializedGraph = f.read()
    # Deserialize the protobuf and import every node into `model` with no
    # name prefix (name=""), so tensors keep the names stored in the file.
    graphDef = tf.GraphDef()
    graphDef.ParseFromString(serializedGraph)
    tf.import_graph_def(graphDef, name="")
# Build the category index (class id -> category entry) from the label map,
# preferring each item's display name and capping at num_classes_file classes.
labelMap = label_map_util.load_labelmap(labels_file)
categories = label_map_util.convert_label_map_to_categories(
    labelMap,
    max_num_classes=num_classes_file,
    use_display_name=True,
)
categoryIdx = label_map_util.create_category_index(categories)

# Post-processing helpers: PlateFinder filters raw detections into plates and
# characters (plates are not rejected; charIOUMax is the overlap limit it is
# given for character boxes -- exact semantics live in base2designs, confirm).
# PlateDisplay draws the results onto an image.
plateFinder = PlateFinder(
    confidence_file,
    categoryIdx,
    rejectPlates=False,
    charIOUMax=0.3,
)
plateDisplay = PlateDisplay()
def _predict_two_stage(predicter, plateFinder, image):
    """Detect plates on the full image, then detect characters per plate.

    Returns a ``(plateBoxes, charTexts, charBoxes)`` tuple taken from the
    output of ``plateFinder.processPlates``.
    """
    (H, W) = image.shape[:2]
    # Stage 1: inference on the full image, keeping only the plate boxes.
    boxes, scores, labels = predicter.predictPlates(image, preprocess=True)
    _, plateBoxes, plateScores = plateFinder.findPlatesOnly(boxes, scores, labels)
    # Stage 2: find the chars inside each plate box. One entry is appended per
    # plate box; None marks a plate in which no chars were found.
    plates = []
    for plateBox in plateBoxes:
        charBoxes, charScores, charLabels = predicter.predictChars(image, plateBox)
        chars = plateFinder.findCharsOnly(charBoxes, charScores, charLabels, plateBox, H, W)
        # len() rather than truthiness: the container type of `chars` is not
        # visible here (a numpy array would raise on bool()).
        plates.append(chars if len(chars) > 0 else None)
    # processPlates scrubs the chars and yields the final plate boxes, char
    # texts, char boxes, char scores and average plate scores.
    plateBoxes, charTexts, charBoxes, _, _ = plateFinder.processPlates(plates, plateBoxes, plateScores)
    return plateBoxes, charTexts, charBoxes


def _predict_one_stage(predicter, plateFinder, image):
    """Detect plates and their characters in a single inference pass.

    Returns a ``(plateBoxes, charTexts, charBoxes)`` tuple taken from the
    output of ``plateFinder.findPlates``.
    """
    boxes, scores, labels = predicter.predictPlates(image, preprocess=False)
    _, plateBoxes, charTexts, charBoxes, _, _ = plateFinder.findPlates(boxes, scores, labels)
    return plateBoxes, charTexts, charBoxes


# Number of prediction stages -> per-image prediction routine.
_PREDICTORS = {1: _predict_one_stage, 2: _predict_two_stage}
predictFn = _PREDICTORS.get(args["pred_stages"])
if predictFn is None:
    # Validate once, before any model/session work, and exit with a non-zero
    # status so callers and shells can detect the misconfiguration.
    raise SystemExit("[ERROR] --pred_stages {}. The number of prediction stages must be either 1 or 2".format(args["pred_stages"]))

# create a session to perform inference
with model.as_default():
    with tf.Session(graph=model) as sess:
        # create a predicter, used to predict plates and chars
        predicter = Predicter(model, sess, categoryIdx)
        imagePaths = paths.list_images(images_file)
        frameCnt = 0
        # Time spent blocked in cv2.waitKey(0) waiting for a key press; it is
        # subtracted below so the frame rate reflects processing only.
        waitTime = 0.0
        start_time = time.time()
        # Loop over all the images
        for imagePath in imagePaths:
            # load the image from disk
            print("[INFO] Loading image \"{}\"".format(imagePath))
            image = cv2.imread(imagePath)
            if image is None:
                # cv2.imread signals unreadable/unsupported files by returning
                # None instead of raising; skip them rather than crash.
                print("[WARNING] Could not read image \"{}\", skipping".format(imagePath))
                continue
            frameCnt += 1

            plateBoxes_pred, charTexts_pred, charBoxes_pred = predictFn(predicter, plateFinder, image)

            # Print plate text
            for charText in charTexts_pred:
                print(" Found: ", charText)

            # Display the full image with predicted plates and chars
            if args["image_display"]:
                imageLabelled = plateDisplay.labelImage(image, plateBoxes_pred, charBoxes_pred, charTexts_pred)
                cv2.imshow("Labelled Image", imageLabelled)
                waitStart = time.time()
                cv2.waitKey(0)
                waitTime += time.time() - waitStart

if args["image_display"]:
    cv2.destroyAllWindows()

# print some performance statistics
processingTime = time.time() - start_time - waitTime
# Guard against a zero interval (e.g. no images and a coarse clock).
fps = frameCnt / processingTime if processingTime > 0 else 0.0
print("[INFO] Processed {} frames in {:.2f} seconds. Frame rate: {:.2f} Hz".format(frameCnt, processingTime, fps))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.