Created
October 28, 2018 15:12
-
-
Save bacher09/87fb2c49339a2c68f32155f1d8129c71 to your computer and use it in GitHub Desktop.
OpenCV EAST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
import argparse | |
from collections import namedtuple | |
import time | |
def draw_rotated_rect(image, rect, color=(0, 255, 0), thickness=1):
    """Draw the outline of ``rect`` onto ``image`` in place.

    Args:
        image: image array passed to ``cv2.line`` (modified in place).
        rect: object whose ``points()`` returns the four (x, y) corners,
            e.g. ``RotatedRect``.
        color: line colour forwarded to ``cv2.line``.
        thickness: line thickness forwarded to ``cv2.line``.
    """
    corners = np.round(rect.points()).astype(int)
    # Connect consecutive corners and close the polygon back to corner 0.
    for start_idx, end_idx in ((0, 1), (1, 2), (2, 3), (3, 0)):
        start = tuple(corners[start_idx])
        end = tuple(corners[end_idx])
        cv2.line(image, start, end, color, thickness)
class RotatedRect(namedtuple("RotatedRect", ["center", "size", "angle"])):
    """Immutable rotated rectangle mirroring OpenCV's ``cv::RotatedRect``.

    Fields:
        center: (x, y) centre of the rectangle (tuple or array).
        size: (width, height) of the rectangle.
        angle: rotation in degrees.
    """

    __slots__ = ()

    def points(self):
        """Return the four corner points as a ``(4, 2)`` float64 array.

        Port of ``cv::RotatedRect::points``; corners come out in the same
        order as OpenCV's implementation. Opposite corners are reflections
        of each other through the centre.
        """
        # https://github.com/opencv/opencv/blob/808ba552c532408bddd5fe51784cf4209296448a/modules/core/src/types.cpp#L173
        # Convert so that ``center * 2`` is arithmetic even when the caller
        # passed a plain tuple (tuple * 2 would repeat the tuple instead).
        center = np.asarray(self.center, dtype=np.float64)
        width, height = self.size[0], self.size[1]
        rad_angle = np.deg2rad(self.angle)
        b = np.cos(rad_angle) * 0.5
        a = np.sin(rad_angle) * 0.5
        p0 = np.array([
            center[0] - a * height - b * width,
            center[1] + b * height - a * width,
        ])
        p1 = np.array([
            center[0] + a * height - b * width,
            # OpenCV subtracts a * width here. Adding it (as before) skewed
            # the box and collapsed p1 onto p2 at 90 degrees.
            center[1] - b * height - a * width,
        ])
        p2 = center * 2 - p0
        p3 = center * 2 - p1
        return np.array([p0, p1, p2, p3])
class EAST:
    """Thin wrapper around an EAST text-detection network loaded via ``cv2.dnn``.

    ``process`` runs the network on an image and yields ``(RotatedRect, score)``
    pairs mapped back to the image's own pixel coordinates.
    """

    # Output layer names fetched in one forward pass: the per-cell score map
    # and the geometry map consumed by ``decodeDetections``.
    OUTPUTS = [
        "feature_fusion/Conv_7/Sigmoid",
        "feature_fusion/concat_3",
    ]

    def __init__(self, net_path, resolution=(320, 320)):
        """Load the network.

        Args:
            net_path: model path accepted by ``cv2.dnn.readNet``.
            resolution: ``(width, height)`` the input is resized to for inference.
        """
        self.net = cv2.dnn.readNet(net_path)
        self.resolution = resolution

    def process(self, image, score_threshold=0.5, nms_threshold=0.4):
        """Detect text in ``image``.

        Args:
            image: image array (rows, cols, ...); ``swapRB=True`` is applied
                when building the blob, so presumably BGR as read by OpenCV.
            score_threshold: minimum cell score kept.
            nms_threshold: overlap threshold for rotated non-maximum suppression.

        Yields:
            ``(RotatedRect, float)`` pairs in ``image`` pixel coordinates.
        """
        blob = self.img_to_blob(image, self.resolution)
        h, w = image.shape[0], image.shape[1]
        # Detections are decoded in network-input pixels; map back to the image.
        scale = np.array([w / self.resolution[0], h / self.resolution[1]])
        for rect, p in self.processBlob(blob, score_threshold, nms_threshold):
            # NOTE: per-axis scaling of ``size`` is only exact when the box is
            # axis-aligned or the scale is uniform.
            resized_rect = RotatedRect(rect.center * scale, rect.size * scale, rect.angle)
            yield resized_rect, p

    def processBlob(self, blob, score_threshold=0.5, nms_threshold=0.4):
        """Run the network on a prepared blob and yield NMS-filtered boxes.

        Yields:
            ``(RotatedRect, float)`` pairs in network-input pixel coordinates.
        """
        self.net.setInput(blob)
        scores, geometry = self.net.forward(self.OUTPUTS)
        boxes, confidences = [], []
        for rect, score in self.decodeDetections(scores, geometry, score_threshold):
            boxes.append(rect)
            confidences.append(score)
        if not boxes:
            return
        indices = cv2.dnn.NMSBoxesRotated(boxes, confidences, score_threshold,
                                          nms_threshold)
        # The result may be an empty tuple, an (N, 1) array or a flat array
        # depending on the OpenCV version. Comparing an ndarray to ``()`` is
        # unreliable, so normalise to a flat integer array instead.
        for i in np.asarray(indices, dtype=int).reshape(-1):
            yield boxes[i], confidences[i]

    def netTime(self):
        """Return the last inference time in seconds, from ``getPerfProfile``."""
        return self.net.getPerfProfile()[0] / cv2.getTickFrequency()

    @staticmethod
    def img_to_blob(image, blob_shape=None):
        """Convert ``image`` into a network input blob.

        Args:
            image: image array (rows, cols, ...).
            blob_shape: ``(width, height)`` of the blob. Defaults to the
                image's own size.
        """
        if blob_shape is None:
            # OpenCV sizes are (width, height); ndarray shape is (rows, cols).
            blob_shape = (image.shape[1], image.shape[0])
        means = [123.68, 116.78, 103.94]
        return cv2.dnn.blobFromImage(image, scalefactor=1.0, size=blob_shape,
                                     mean=means, swapRB=True, crop=False)

    @staticmethod
    def decodeDetections(scores, geometry, score_threshold):
        """Decode raw EAST outputs into rotated boxes in network-input pixels.

        Port of the decoding loop in OpenCV's C++ sample:
        https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.cpp#L119

        Args:
            scores: score map indexed as ``scores[0, 0, y, x]``.
            geometry: geometry map. Channels 0-3 are combined into box
                height/width (presumably edge distances) and channel 4 is the
                rotation angle in radians.
            score_threshold: cells scoring below this are skipped.

        Yields:
            ``(RotatedRect, float)`` pairs in row-major cell order.
        """
        height, width = scores.shape[2], scores.shape[3]

        def point2f(x, y):
            return np.array([x, y], dtype=np.float64)

        for y in range(height):
            scores_data = scores[0, 0, y]
            x0_data = geometry[0, 0, y]
            x1_data = geometry[0, 1, y]
            x2_data = geometry[0, 2, y]
            x3_data = geometry[0, 3, y]
            angles_data = geometry[0, 4, y]
            for x in range(width):
                score = scores_data[x]
                if score < score_threshold:
                    continue
                # Output maps are 4x smaller than the network input.
                offset_x = x * 4.
                offset_y = y * 4.
                angle = angles_data[x]
                cos_a = np.cos(angle)
                sin_a = np.sin(angle)
                h = x0_data[x] + x2_data[x]
                w = x1_data[x] + x3_data[x]
                offset = point2f(
                    offset_x + cos_a * x1_data[x] + sin_a * x2_data[x],
                    offset_y - sin_a * x1_data[x] + cos_a * x2_data[x]
                )
                p1 = point2f(-sin_a * h, -cos_a * h) + offset
                p3 = point2f(-cos_a * w, sin_a * w) + offset
                box_center = 0.5 * (p1 + p3)
                # Radians -> degrees with the sign flip used by the C++ sample.
                box_angle = -angle * 180. / np.pi
                rect = RotatedRect(box_center, point2f(w, h), box_angle)
                yield rect, float(score)
def main():
    """Command-line entry point: run EAST on a video/camera and display the boxes."""

    def validate_size(value):
        """argparse ``type=`` hook: parse a positive multiple of 32."""
        size = int(value)
        if size <= 0 or size % 32 != 0:
            # ArgumentTypeError makes argparse show this message to the user.
            raise argparse.ArgumentTypeError(
                "Size should be a positive multiple of 32, got {!r}".format(value))
        return size

    parser = argparse.ArgumentParser(
        description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
                    "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)"
    )
    parser.add_argument(
        "-m", "--model", type=str, required=True,
        help="Path to a binary .pb file contains trained network."
    )
    parser.add_argument(
        "-i", "--input", type=str,
        help="Path to input image or video file. Skip this argument to capture frames from a camera."
    )
    parser.add_argument(
        "--width", default=320, type=validate_size,
        help="Preprocess input image by resizing to a specific width. It should be multiple by 32."
    )
    parser.add_argument(
        "--height", default=320, type=validate_size,
        help="Preprocess input image by resizing to a specific height. It should be multiple by 32."
    )
    parser.add_argument(
        "--thr", type=float, default=0.5,
        help="Confidence threshold."
    )
    parser.add_argument(
        "--nms", type=float, default=0.4,
        help="Non-maximum suppression threshold."
    )
    args = parser.parse_args()

    # Fall back to the default camera (index 0) when no input path is given.
    source = args.input if args.input else 0
    cap = cv2.VideoCapture()
    cap.open(source)
    if not cap.isOpened():
        raise SystemExit("Cannot open video source: {!r}".format(source))

    try:
        winName = "EAST: An Efficient and Accurate Scene Text Detector"
        cv2.namedWindow(winName)
        east = EAST(args.model, (args.width, args.height))
        while cv2.waitKey(1) < 0:
            ok, frame = cap.read()
            if not ok:
                # End of stream: keep the last shown frame until a key is pressed.
                cv2.waitKey()
                break
            for rect, _ in east.process(frame, args.thr, args.nms):
                draw_rotated_rect(frame, rect)
            label = "Inference time: {:.2f} s".format(east.netTime())
            cv2.putText(frame, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
            cv2.imshow(winName, frame)
    finally:
        cap.release()
        cv2.destroyAllWindows()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment