@bacher09
Created October 28, 2018 15:12
OpenCV EAST
import cv2
import numpy as np
import argparse
from collections import namedtuple
import time


def draw_rotated_rect(image, rect, color=(0, 255, 0), thickness=1):
    vertices = np.round(rect.points()).astype(int)
    for i in range(4):
        cv2.line(image, tuple(vertices[i]), tuple(vertices[(i + 1) % 4]), color, thickness)


class RotatedRect(namedtuple("RotatedRect", ["center", "size", "angle"])):
    __slots__ = ()

    def points(self):
        # https://github.com/opencv/opencv/blob/808ba552c532408bddd5fe51784cf4209296448a/modules/core/src/types.cpp#L173
        rad_angle = np.deg2rad(self.angle)
        b = np.cos(rad_angle) * 0.5
        a = np.sin(rad_angle) * 0.5
        p0 = np.array([
            self.center[0] - a * self.size[1] - b * self.size[0],
            self.center[1] + b * self.size[1] - a * self.size[0]
        ])
        p1 = np.array([
            self.center[0] + a * self.size[1] - b * self.size[0],
            self.center[1] - b * self.size[1] + a * self.size[0]
        ])
        p2 = self.center * 2 - p0
        p3 = self.center * 2 - p1
        return np.array([p0, p1, p2, p3])
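
# A minimal usage sketch (not part of the original gist): RotatedRect mirrors
# cv2.RotatedRect, so a 100x40 box centered at (80, 80) and rotated by 30
# degrees could be built and drawn onto a blank canvas like this:
#
#     canvas = np.zeros((160, 160, 3), dtype=np.uint8)
#     rect = RotatedRect(np.array([80.0, 80.0]), np.array([100.0, 40.0]), 30.0)
#     draw_rotated_rect(canvas, rect, color=(0, 255, 0))
#     cv2.imwrite("rotated_rect.png", canvas)
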
class EAST:
    OUTPUTS = [
        "feature_fusion/Conv_7/Sigmoid",
        "feature_fusion/concat_3"
    ]

    def __init__(self, net_path, resolution=(320, 320)):
        self.net = cv2.dnn.readNet(net_path)
        self.resolution = resolution

    def process(self, image, score_threshold=0.5, nms_threshold=0.4):
        blob = self.img_to_blob(image, self.resolution)
        h, w = image.shape[0], image.shape[1]
        # scale detections from network resolution back to the original image size
        scale = np.array([w / self.resolution[0], h / self.resolution[1]])
        for rect, p in self.processBlob(blob, score_threshold, nms_threshold):
            resized_rect = RotatedRect(rect.center * scale, rect.size * scale, rect.angle)
            yield resized_rect, p

    def processBlob(self, blob, score_threshold=0.5, nms_threshold=0.4):
        self.net.setInput(blob)
        scores, geometry = self.net.forward(self.OUTPUTS)
        decode_iter = self.decodeDetections(scores, geometry, score_threshold)
        boxes, confidences = [], []
        for rect, score in decode_iter:
            boxes.append(rect)
            confidences.append(score)
        indices = cv2.dnn.NMSBoxesRotated(boxes, confidences, score_threshold,
                                          nms_threshold)
        if len(indices) == 0:
            return
        for i in np.asarray(indices).reshape(-1):
            yield boxes[i], confidences[i]

    def netTime(self):
        return self.net.getPerfProfile()[0] / cv2.getTickFrequency()

    @staticmethod
    def img_to_blob(image, blob_shape=None):
        if blob_shape is None:
            # blobFromImage expects the target size as (width, height)
            blob_shape = (image.shape[1], image.shape[0])
        means = [123.68, 116.78, 103.94]
        return cv2.dnn.blobFromImage(image, 1.0, blob_shape, means, True, False)

    @staticmethod
    def decodeDetections(scores, geometry, score_threshold):
        """
        https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.cpp#L119
        """
        height, width = scores.shape[2], scores.shape[3]
        point2f = lambda x, y: np.array([x, y], dtype=np.float64)
        for y in range(height):
            scores_data = scores[0, 0, y]
            x0_data = geometry[0, 0, y]
            x1_data = geometry[0, 1, y]
            x2_data = geometry[0, 2, y]
            x3_data = geometry[0, 3, y]
            angles_data = geometry[0, 4, y]
            for x in range(width):
                score = scores_data[x]
                if score < score_threshold:
                    continue
                # each output cell corresponds to a 4x4 pixel region of the network input
                offset_x = x * 4.
                offset_y = y * 4.
                angle = angles_data[x]
                cos_a = np.cos(angle)
                sin_a = np.sin(angle)
                h = x0_data[x] + x2_data[x]
                w = x1_data[x] + x3_data[x]
                offset = point2f(
                    offset_x + cos_a * x1_data[x] + sin_a * x2_data[x],
                    offset_y - sin_a * x1_data[x] + cos_a * x2_data[x]
                )
                p1 = point2f(-sin_a * h, -cos_a * h) + offset
                p3 = point2f(-cos_a * w, sin_a * w) + offset
                box_center = 0.5 * (p1 + p3)
                box_angle = -np.rad2deg(angle)
                rect = RotatedRect(box_center, point2f(w, h), box_angle)
                yield rect, float(score)
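
# A minimal programmatic usage sketch (file names below are illustrative, not
# from the original gist): assuming a frozen EAST graph such as
# "frozen_east_text_detection.pb" is available, the EAST class can be used on a
# single image without going through the CLI in main():
#
#     east = EAST("frozen_east_text_detection.pb", (320, 320))
#     image = cv2.imread("scene.jpg")
#     for rect, confidence in east.process(image, score_threshold=0.5, nms_threshold=0.4):
#         draw_rotated_rect(image, rect)
#     cv2.imwrite("scene_with_boxes.jpg", image)
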
def main():
    def validate_size(value):
        size = int(value)
        if size % 32 != 0:
            raise argparse.ArgumentTypeError("Size must be a multiple of 32")
        return size

    parser = argparse.ArgumentParser(
        description="Use this script to run the TensorFlow implementation (https://github.com/argman/EAST) of "
                    "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)"
    )
    parser.add_argument(
        "-m", "--model", type=str, required=True,
        help="Path to a binary .pb file that contains the trained network."
    )
    parser.add_argument(
        "-i", "--input", type=str,
        help="Path to an input image or video file. Skip this argument to capture frames from a camera."
    )
    parser.add_argument(
        "--width", default=320, type=validate_size,
        help="Preprocess the input image by resizing it to a specific width. Must be a multiple of 32."
    )
    parser.add_argument(
        "--height", default=320, type=validate_size,
        help="Preprocess the input image by resizing it to a specific height. Must be a multiple of 32."
    )
    parser.add_argument(
        "--thr", type=float, default=0.5,
        help="Confidence threshold."
    )
    parser.add_argument(
        "--nms", type=float, default=0.4,
        help="Non-maximum suppression threshold."
    )
    args = parser.parse_args()

    cap = cv2.VideoCapture()
    if args.input:
        cap.open(args.input)
    else:
        cap.open(0)

    winName = "EAST: An Efficient and Accurate Scene Text Detector"
    cv2.namedWindow(winName)
    east = EAST(args.model, (args.width, args.height))

    while cv2.waitKey(1) < 0:
        ok, frame = cap.read()
        if not ok:
            cv2.waitKey()
            break
        for rect, _ in east.process(frame, args.thr, args.nms):
            draw_rotated_rect(frame, rect)
        label = "Inference time: {:.2f} s".format(east.netTime())
        cv2.putText(frame, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
        cv2.imshow(winName, frame)


if __name__ == '__main__':
    main()
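
# CLI usage sketch (the script and model file names below are illustrative,
# not from the original gist):
#
#     python east_text_detection.py -m frozen_east_text_detection.pb -i input.jpg \
#         --width 320 --height 320 --thr 0.5 --nms 0.4
#
# Omitting -i/--input makes the script read frames from the default camera (device 0).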