-
-
Save solaris33/b1eb323648a624e9eb827e8cb9641c51 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding:utf-8 | |
import glob | |
import sys | |
import csv | |
import cv2 | |
import time | |
import os | |
import argparse | |
import itertools | |
from multiprocessing import Pool | |
import threading | |
import numpy as np | |
import scipy.optimize | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as Patches | |
from shapely.geometry import Polygon | |
import tensorflow as tf | |
def get_images(data_path):
    '''
    collect the image files found directly inside data_path

    Files are matched by the extensions jpg, png, jpeg and JPG, in that
    order (the directory is not searched recursively). Where glob matching
    is case-insensitive (e.g. on Windows) several of these patterns can
    match the same file, so duplicates are dropped, keeping the first
    occurrence.
    :param data_path: directory to search
    :return: list of matching file paths
    '''
    files = []
    seen = set()
    for ext in ['jpg', 'png', 'jpeg', 'JPG']:
        for path in glob.glob(os.path.join(data_path, '*.{}'.format(ext))):
            if path not in seen:
                seen.add(path)
                files.append(path)
    return files
def load_annotation(p):
    '''
    load annotation from the text file

    Each row holds the four corner points of a text region
    (x1 y1 x2 y2 x3 y3 x4 y4) followed by its transcription. Both
    space-separated rows and comma-separated (ICDAR-style) rows are
    accepted; blank rows are skipped. A transcription of '*' or '###'
    marks the region as "don't care".
    :param p: path of the annotation file
    :return: (polys, tags) -- polys is a float32 array of shape (N, 4, 2),
             tags a bool array of shape (N,) that is True for "don't care"
             regions. A missing file yields two empty arrays.
    '''
    text_polys = []
    text_tags = []
    if not os.path.exists(p):
        # callers always unpack two values, so return an empty pair
        return (np.zeros((0, 4, 2), dtype=np.float32),
                np.zeros((0,), dtype=bool))
    with open(p, 'r', newline='') as f:
        reader = csv.reader(f)
        for line in reader:
            if not line:
                continue
            # strip BOM. \ufeff for python3, \xef\xbb\xbf for python2
            fields = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
            if len(fields) >= 9:
                # comma separated: x1,y1,...,x4,y4,transcription
                coords = fields[:8]
                label = fields[-1]
            else:
                # space separated: the first csv field holds the whole row
                tokens = fields[0].split()
                coords = tokens[:8]
                label = tokens[-1] if len(fields) == 1 else fields[-1]
            x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, coords))
            text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
            text_tags.append(label == '*' or label == '###')
    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent
    return (np.array(text_polys, dtype=np.float32).reshape((-1, 4, 2)),
            np.array(text_tags, dtype=bool))
def polygon_area(poly):
    '''
    return the signed area of a quadrilateral

    Uses the trapezoid form of the shoelace formula over the closed edge
    cycle p0->p1->p2->p3->p0. Only the first four vertices are read. The
    sign depends on vertex order: [[0, 0], [1, 0], [1, 1], [0, 1]] gives -1.
    :param poly: four (x, y) vertices
    :return: signed area
    '''
    terms = []
    for cur, nxt in ((0, 1), (1, 2), (2, 3), (3, 0)):
        terms.append((poly[nxt][0] - poly[cur][0]) * (poly[nxt][1] + poly[cur][1]))
    return np.sum(terms) / 2.
def check_and_validate_polys(FLAGS, polys, tags, size):
    '''
    check so that the text poly is in the same direction,
    and also filter some invalid polygons

    Vertices are clipped (in place) to the image. Polygons whose absolute
    area is below 1 are dropped. Polygons with a positive signed area (as
    computed by polygon_area) get their vertex order reversed, so every kept
    polygon has the orientation of (0,0),(1,0),(1,1),(0,1).
    :param FLAGS: options object; suppress_warnings_and_error_messages is read
    :param polys: float array of shape (N, 4, 2); modified in place by clipping
    :param tags: N per-polygon tags
    :param size: (h, w) of the image
    :return: (validated_polys, validated_tags); the polys array always has
             shape (M, 4, 2), even when M == 0
    '''
    (h, w) = size
    if polys.shape[0] == 0:
        # always return a pair -- every caller unpacks two values
        return polys, tags
    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
    validated_polys = []
    validated_tags = []
    for poly, tag in zip(polys, tags):
        p_area = polygon_area(poly)
        if abs(p_area) < 1:
            if not FLAGS.suppress_warnings_and_error_messages:
                print('invalid poly')
            continue
        if p_area > 0:
            if not FLAGS.suppress_warnings_and_error_messages:
                print('poly in wrong direction')
            poly = poly[[0, 3, 2, 1], :]
        validated_polys.append(poly)
        validated_tags.append(tag)
    # reshape keeps the (M, 4, 2) layout even if every polygon was dropped
    return (np.array(validated_polys, dtype=polys.dtype).reshape((-1, 4, 2)),
            np.array(validated_tags, dtype=np.asarray(tags).dtype))
def crop_area(FLAGS, im, polys, tags, crop_background=False, max_tries=50):
    '''
    make random crop from the input image

    Up to max_tries candidate windows are sampled. Window edges are drawn
    only from rows/columns (in a range padded by 10% on each side, then
    clipped to the image) that no polygon's bounding box spans, so a crop
    never cuts through text.
    :param FLAGS: options object; min_crop_side_ratio is read
    :param im: image indexed as [row, col, channel]
    :param polys: array of shape (N, 4, 2) in (x, y) pixel coordinates
    :param tags: per-polygon tags, indexed together with polys
    :param crop_background: if True, return the first large-enough window
        that contains no complete polygon; if False, return the first one
        that contains at least one complete polygon
    :param max_tries: number of windows sampled before giving up
    :return: (im, polys, tags) for the crop, with polys shifted into crop
        coordinates; the inputs are returned unchanged when nothing qualifies
    '''
    img_h, img_w, _ = im.shape
    margin_y = img_h // 10
    margin_x = img_w // 10
    row_taken = np.zeros((img_h + margin_y * 2), dtype=np.int32)
    col_taken = np.zeros((img_w + margin_x * 2), dtype=np.int32)
    for quad in polys:
        quad = np.round(quad, decimals=0).astype(np.int32)
        xs, ys = quad[:, 0], quad[:, 1]
        col_taken[np.min(xs) + margin_x:np.max(xs) + margin_x] = 1
        row_taken[np.min(ys) + margin_y:np.max(ys) + margin_y] = 1
    # candidate window edges must not fall inside any text span
    free_rows = np.where(row_taken == 0)[0]
    free_cols = np.where(col_taken == 0)[0]
    if not (len(free_rows) and len(free_cols)):
        return im, polys, tags
    min_side = FLAGS.min_crop_side_ratio
    for _ in range(max_tries):
        # draw the two column edges first, then the two row edges
        col_pick = np.random.choice(free_cols, size=2)
        left = np.clip(np.min(col_pick) - margin_x, 0, img_w - 1)
        right = np.clip(np.max(col_pick) - margin_x, 0, img_w - 1)
        row_pick = np.random.choice(free_rows, size=2)
        top = np.clip(np.min(row_pick) - margin_y, 0, img_h - 1)
        bottom = np.clip(np.max(row_pick) - margin_y, 0, img_h - 1)
        if right - left < min_side * img_w or bottom - top < min_side * img_h:
            # window too small
            continue
        if polys.shape[0] == 0:
            chosen = []
        else:
            inside = (polys[:, :, 0] >= left) & (polys[:, :, 0] <= right) \
                & (polys[:, :, 1] >= top) & (polys[:, :, 1] <= bottom)
            # keep only polygons with all four vertices inside the window
            chosen = np.where(np.sum(inside, axis=1) == 4)[0]
        if len(chosen) == 0:
            if not crop_background:
                continue
            # no text in this window: usable as a background crop
            return im[top:bottom + 1, left:right + 1, :], polys[chosen], tags[chosen]
        kept_polys = polys[chosen]
        kept_polys[:, :, 0] -= left
        kept_polys[:, :, 1] -= top
        return im[top:bottom + 1, left:right + 1, :], kept_polys, tags[chosen]
    return im, polys, tags
def _shrink_pair(poly, r, ratio, start, end, along_y):
    '''
    move poly[start] and poly[end] towards each other along their edge (in place)

    poly[start] advances by ratio * r[start] and poly[end] retreats by
    ratio * r[end]. With along_y the angle is taken as arctan2(dx, dy) and
    sin/cos are applied to x/y; otherwise arctan2(dy, dx) with cos/sin.
    '''
    if along_y:
        theta = np.arctan2((poly[end][0] - poly[start][0]), (poly[end][1] - poly[start][1]))
        step_x, step_y = np.sin(theta), np.cos(theta)
    else:
        theta = np.arctan2((poly[end][1] - poly[start][1]), (poly[end][0] - poly[start][0]))
        step_x, step_y = np.cos(theta), np.sin(theta)
    poly[start][0] += ratio * r[start] * step_x
    poly[start][1] += ratio * r[start] * step_y
    poly[end][0] -= ratio * r[end] * step_x
    poly[end][1] -= ratio * r[end] * step_y


def shrink_poly(poly, r):
    '''
    fit a poly inside the origin poly, maybe bugs here...
    used for generating the score map

    Each vertex i is pulled inwards along its two edges by 0.3 * r[i]. The
    pair of opposite edges with the larger total length is processed first.
    :param poly: the text poly (4 x 2); it is modified in place
    :param r: per-vertex reference lengths ("r in the paper")
    :return: the shrinked poly (the same object as the input)
    '''
    R = 0.3  # shrink ratio
    horizontal_moves = [(0, 1, False), (3, 2, False)]  # edges p0-p1, p3-p2
    vertical_moves = [(0, 3, True), (1, 2, True)]      # edges p0-p3, p1-p2
    p01_p23_longer = np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
        np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2])
    if p01_p23_longer:
        moves = horizontal_moves + vertical_moves
    else:
        moves = vertical_moves + horizontal_moves
    for start, end, along_y in moves:
        _shrink_pair(poly, r, R, start, end, along_y)
    return poly
def point_dist_to_line(p1, p2, p3):
    '''
    compute the distance from point p3 to the infinite line through p1 and p2

    The 2-D cross product is written out explicitly because np.cross on
    2-element vectors is deprecated as of NumPy 2.0. Inputs may be arrays
    or plain sequences.
    :param p1: first (x, y) point of the line
    :param p2: second (x, y) point of the line (must differ from p1,
        otherwise the result is inf/nan)
    :param p3: (x, y) point to measure
    :return: non-negative perpendicular distance
    '''
    p1 = np.asarray(p1)
    p2 = np.asarray(p2)
    p3 = np.asarray(p3)
    direction = p2 - p1
    offset = p1 - p3
    cross = direction[0] * offset[1] - direction[1] * offset[0]
    return np.abs(cross) / np.linalg.norm(direction)
def fit_line(p1, p2):
    '''
    fit the line a*x + b*y + c = 0 through two points

    Despite the parameter names, p1 holds the two x coordinates and p2
    the two y coordinates (callers pass fit_line([xa, xb], [ya, yb])).
    :return: [1., 0., -x] for a vertical line, otherwise [k, -1., b]
             for y = k*x + b
    '''
    xs, ys = p1, p2
    if xs[0] != xs[1]:
        slope, intercept = np.polyfit(xs, ys, deg=1)
        return [slope, -1., intercept]
    return [1., 0., -xs[0]]
def line_cross_point(FLAGS, line1, line2):
    '''
    compute the intersection point of two lines

    Lines use the a*x + b*y + c = 0 form produced by fit_line and
    line_verticle: either [k, -1, b] (y = k*x + b) or [1, 0, -x0]
    (the vertical line x = x0).
    :param FLAGS: options object; suppress_warnings_and_error_messages is read
    :return: float32 array [x, y], or None when the lines are parallel
    '''
    # Parallel iff the 2x2 determinant vanishes. Comparing only the `a`
    # coefficients would wrongly treat a vertical line [1, 0, c] and a
    # slope-1 line [1, -1, b] as parallel.
    if line1[0] * line2[1] - line2[0] * line1[1] == 0:
        if not FLAGS.suppress_warnings_and_error_messages:
            print('Cross point does not exist')
        return None
    if line1[1] == 0:
        # line1 is vertical: x is fixed, read y off line2
        x = -line1[2]
        y = line2[0] * x + line2[2]
    elif line2[1] == 0:
        x = -line2[2]
        y = line1[0] * x + line1[2]
    else:
        k1, _, b1 = line1
        k2, _, b2 = line2
        x = -(b1 - b2) / (k1 - k2)
        y = k1 * x + b1
    return np.array([x, y], dtype=np.float32)
def line_verticle(line, point):
    '''
    get the line perpendicular to `line` that passes through `point`

    Lines use the a*x + b*y + c = 0 form ([k, -1, b] or [1, 0, -x0]).
    The function name keeps its existing spelling.
    :return: the perpendicular line in the same form
    '''
    if line[1] == 0:
        # vertical input -> horizontal output y = point[1]
        return [0, -1, point[1]]
    if line[0] == 0:
        # horizontal input -> vertical output x = point[0]
        return [1, 0, -point[0]]
    normal_slope = -1. / line[0]
    return [normal_slope, -1, point[1] - normal_slope * point[0]]
def _foot_on_line(FLAGS, a, b, point):
    '''project point perpendicularly onto the line through a and b'''
    line = fit_line([a[0], b[0]], [a[1], b[1]])
    return line_cross_point(FLAGS, line, line_verticle(line, point))


def rectangle_from_parallelogram(FLAGS, poly):
    '''
    fit a rectangle from a parallelogram

    Two opposite corners are kept: p0/p2 when the angle at p0 is acute,
    otherwise p1/p3. The other two corners are replaced by the feet of the
    perpendiculars dropped from the kept corners onto the longer pair of
    edges.
    :param FLAGS: options object, forwarded to line_cross_point
    :param poly: parallelogram as four (x, y) points p0..p3
    :return: float32 array of shape (4, 2); building it fails if a
             projection cannot be computed (line_cross_point returns None)
    '''
    p0, p1, p2, p3 = poly
    cos_p0 = np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0))
    angle_p0 = np.arccos(cos_p0)
    p0p1_longer = np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3)
    if angle_p0 < 0.5 * np.pi:
        # keep p0 and p2, rebuild p1 and p3
        if p0p1_longer:
            new_p3 = _foot_on_line(FLAGS, p2, p3, p0)
            new_p1 = _foot_on_line(FLAGS, p0, p1, p2)
        else:
            new_p1 = _foot_on_line(FLAGS, p1, p2, p0)
            new_p3 = _foot_on_line(FLAGS, p0, p3, p2)
        return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
    # keep p1 and p3, rebuild p0 and p2
    if p0p1_longer:
        new_p2 = _foot_on_line(FLAGS, p2, p3, p1)
        new_p0 = _foot_on_line(FLAGS, p0, p1, p3)
    else:
        new_p0 = _foot_on_line(FLAGS, p0, p3, p1)
        new_p2 = _foot_on_line(FLAGS, p1, p2, p3)
    return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
def sort_rectangle(FLAGS, poly):
    '''
    reorder the corners of a rectangle and compute its rotation angle

    Points in poly should be sorted clockwise; the cyclic order is kept and
    only the starting corner changes.
    :param FLAGS: options object; suppress_warnings_and_error_messages is read
    :param poly: array of shape (4, 2) in (x, y) image coordinates
    :return: (reordered poly, angle in radians). The angle is 0. when the
             lowest edge is horizontal.
    '''
    # the lowest point has the largest y (image y grows downwards)
    lowest = np.argmax(poly[:, 1])

    def starting_at(first):
        return poly[[first, (first + 1) % 4, (first + 2) % 4, (first + 3) % 4]]

    if np.count_nonzero(poly[:, 1] == poly[lowest, 1]) == 2:
        # bottom edge parallel to the x-axis: p0 is the upper-left corner
        return starting_at(np.argmin(np.sum(poly, axis=1))), 0.
    # the point that sits right of the lowest point
    right_of_lowest = (lowest - 1) % 4
    angle = np.arctan(-(poly[lowest][1] - poly[right_of_lowest][1]) / (poly[lowest][0] - poly[right_of_lowest][0]))
    if angle <= 0 and not FLAGS.suppress_warnings_and_error_messages:
        print(angle, poly[lowest], poly[right_of_lowest])
    if angle / np.pi * 180 > 45:
        # the lowest point is p2
        return starting_at((lowest + 2) % 4), -(np.pi / 2 - angle)
    # the lowest point is p3
    return starting_at((lowest + 1) % 4), angle
def _rotate_and_anchor(local_pts, x_coeffs, y_coeffs, anchors):
    '''
    rotate per-box local points and translate them onto their anchors

    :param local_pts: N*5*2 array; points 0-3 are the box corners and point 4
        is the anchor pixel, all in the same box-local frame
    :param x_coeffs: pair of length-N arrays; rotated x = c0*x + c1*y
    :param y_coeffs: pair of length-N arrays; rotated y = c0*x + c1*y
    :param anchors: N*2 array of anchor positions in image coordinates
    :return: N*4*2 array of corners in image coordinates
    '''
    def per_point(coeffs):
        rows = np.array(coeffs).transpose((1, 0))
        return np.repeat(rows, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))  # N*5*2

    rotated = np.concatenate([
        np.sum(per_point(x_coeffs) * local_pts, axis=2)[:, :, np.newaxis],
        np.sum(per_point(y_coeffs) * local_pts, axis=2)[:, :, np.newaxis],
    ], axis=2)  # N*5*2
    shift = anchors - rotated[:, 4, :]
    return rotated[:, :4, :] + shift[:, np.newaxis, :]


def restore_rectangle_rbox(origin, geometry):
    '''
    rebuild box corners from per-pixel RBOX geometry

    :param origin: N*2 array of pixel positions (x, y)
    :param geometry: N*5 array; columns 0-3 are the distances from the pixel
        to the top, right, bottom and left edges, column 4 is the angle
    :return: N*4*2 float array of corners. Boxes with angle >= 0 come first,
        followed by boxes with angle < 0, so rows are not in input order when
        both signs occur.
    '''
    dist = geometry[:, :4]
    angle = geometry[:, 4]

    # boxes with angle >= 0
    keep = angle >= 0
    org, d, ang = origin[keep], dist[keep], angle[keep]
    if org.shape[0] > 0:
        zero = np.zeros(d.shape[0])
        local = np.array([zero, -d[:, 0] - d[:, 2],
                          d[:, 1] + d[:, 3], -d[:, 0] - d[:, 2],
                          d[:, 1] + d[:, 3], zero,
                          zero, zero,
                          d[:, 3], -d[:, 2]])
        local = local.transpose((1, 0)).reshape((-1, 5, 2))  # N*5*2
        boxes_pos = _rotate_and_anchor(local,
                                       [np.cos(ang), np.sin(ang)],
                                       [-np.sin(ang), np.cos(ang)],
                                       org)
    else:
        boxes_pos = np.zeros((0, 4, 2))

    # boxes with angle < 0
    keep = angle < 0
    org, d, ang = origin[keep], dist[keep], angle[keep]
    if org.shape[0] > 0:
        zero = np.zeros(d.shape[0])
        local = np.array([-d[:, 1] - d[:, 3], -d[:, 0] - d[:, 2],
                          zero, -d[:, 0] - d[:, 2],
                          zero, zero,
                          -d[:, 1] - d[:, 3], zero,
                          -d[:, 1], -d[:, 2]])
        local = local.transpose((1, 0)).reshape((-1, 5, 2))  # N*5*2
        boxes_neg = _rotate_and_anchor(local,
                                       [np.cos(-ang), -np.sin(-ang)],
                                       [np.sin(-ang), np.cos(-ang)],
                                       org)
    else:
        boxes_neg = np.zeros((0, 4, 2))
    return np.concatenate([boxes_pos, boxes_neg])
def restore_rectangle(origin, geometry):
    '''
    rebuild box corners from per-pixel geometry

    Only the RBOX layout is supported; this simply delegates to
    restore_rectangle_rbox and returns its result unchanged.
    '''
    return restore_rectangle_rbox(origin=origin, geometry=geometry)
def _line_through(line, point):
    '''return the line parallel to `line` (a*x + b*y + c = 0 form) that passes through point'''
    if line[1] == 0:
        return [1, 0, -point[0]]
    return [line[0], -1, point[1] - line[0] * point[0]]


def _candidate_parallelograms(FLAGS, poly):
    '''
    build the 8 candidate parallelograms for a quadrilateral

    For every edge p0-p1, the opposite side is placed through whichever of
    p2/p3 is farther from it. Then either the forward neighbour edge (p1-p2)
    or the backward neighbour edge (p0-p3) is kept and a parallel copy is
    placed through the farther remaining vertex. Each result is a closed
    list of 5 points (first vertex repeated).
    '''
    fitted_parallelograms = []
    for i in range(4):
        p0 = poly[i]
        p1 = poly[(i + 1) % 4]
        p2 = poly[(i + 2) % 4]
        p3 = poly[(i + 3) % 4]
        edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
        backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
        forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
        if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
            edge_opposite = _line_through(edge, p2)
        else:
            edge_opposite = _line_through(edge, p3)
        # move forward edge (p1 stays fixed)
        new_p2 = line_cross_point(FLAGS, forward_edge, edge_opposite)
        if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
            forward_opposite = _line_through(forward_edge, p0)
        else:
            forward_opposite = _line_through(forward_edge, p3)
        new_p0 = line_cross_point(FLAGS, forward_opposite, edge)
        new_p3 = line_cross_point(FLAGS, forward_opposite, edge_opposite)
        fitted_parallelograms.append([new_p0, p1, new_p2, new_p3, new_p0])
        # or move backward edge (p0 stays fixed)
        new_p3 = line_cross_point(FLAGS, backward_edge, edge_opposite)
        if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
            backward_opposite = _line_through(backward_edge, p1)
        else:
            backward_opposite = _line_through(backward_edge, p2)
        new_p1 = line_cross_point(FLAGS, backward_opposite, edge)
        new_p2 = line_cross_point(FLAGS, backward_opposite, edge_opposite)
        fitted_parallelograms.append([p0, new_p1, new_p2, new_p3, p0])
    return fitted_parallelograms


def _min_area_parallelogram(FLAGS, poly):
    '''pick the smallest-area candidate parallelogram, starting from its vertex with the smallest x + y'''
    candidates = _candidate_parallelograms(FLAGS, poly)
    areas = [Polygon(t).area for t in candidates]
    parallelogram = np.array(candidates[np.argmin(areas)][:-1], dtype=np.float32)
    start = np.argmin(np.sum(parallelogram, axis=1))
    return parallelogram[[start, (start + 1) % 4, (start + 2) % 4, (start + 3) % 4]]


def generate_rbox(FLAGS, im_size, polys, tags):
    '''
    rasterize RBOX training targets for one image

    :param FLAGS: options object; min_text_size and
        suppress_warnings_and_error_messages are read
    :param im_size: (h, w) of the target maps
    :param polys: text polygons, each of shape (4, 2) in (x, y) coordinates
    :param tags: per-polygon flags; True marks "don't care" regions
    :return: (score_map, geo_map, overly_small_text_region_training_mask,
              text_region_boundary_training_mask)
        - score_map: uint8 (h, w), 1 inside the shrunk polygons
        - geo_map: float32 (h, w, 5). For pixels inside a shrunk polygon it
          holds the distances to the top, right, bottom and left edges of the
          fitted rectangle, plus its angle.
        - overly_small_text_region_training_mask: uint8 (h, w); 0 over
          polygons whose short side is below FLAGS.min_text_size or that are
          tagged, 1 elsewhere
        - text_region_boundary_training_mask: uint8 (h, w); 0 on pixels inside
          an original polygon but outside every shrunk polygon, 1 elsewhere
    '''
    h, w = im_size
    # per-pixel polygon id (0 = none). int32 so that ids above 255 do not
    # saturate (a uint8 mask silently lost every polygon after the 254th).
    shrinked_poly_mask = np.zeros((h, w), dtype=np.int32)
    orig_poly_mask = np.zeros((h, w), dtype=np.uint8)
    score_map = np.zeros((h, w), dtype=np.uint8)
    geo_map = np.zeros((h, w, 5), dtype=np.float32)
    # mask used during training, to ignore some hard areas
    overly_small_text_region_training_mask = np.ones((h, w), dtype=np.uint8)
    for poly_idx, (poly, tag) in enumerate(zip(polys, tags)):
        # r[i]: length of the shorter edge touching vertex i
        r = [min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                 np.linalg.norm(poly[i] - poly[(i - 1) % 4])) for i in range(4)]
        # score map
        shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        cv2.fillPoly(score_map, shrinked_poly, 1)
        cv2.fillPoly(shrinked_poly_mask, shrinked_poly, poly_idx + 1)
        poly_int = poly.astype(np.int32)[np.newaxis, :, :]
        cv2.fillPoly(orig_poly_mask, poly_int, 1)
        # if the poly is too small (or tagged "don't care"), ignore it during training
        poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
        poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
        if min(poly_h, poly_w) < FLAGS.min_text_size or tag:
            cv2.fillPoly(overly_small_text_region_training_mask, poly_int, 0)
        xy_in_poly = np.argwhere(shrinked_poly_mask == (poly_idx + 1))
        parallelogram = _min_area_parallelogram(FLAGS, poly)
        rectange = rectangle_from_parallelogram(FLAGS, parallelogram)
        rectange, rotate_angle = sort_rectangle(FLAGS, rectange)
        p0_rect, p1_rect, p2_rect, p3_rect = rectange
        for y, x in xy_in_poly:
            point = np.array([x, y], dtype=np.float32)
            geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)  # top
            geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)  # right
            geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)  # down
            geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)  # left
            geo_map[y, x, 4] = rotate_angle
    # Boundary = inside an original polygon but outside all shrunk ones.
    # Computed with booleans: the old uint8 subtraction wrapped to 2 wherever
    # a shrunk pixel fell outside the original polygon's raster.
    boundary = (orig_poly_mask > 0) & (shrinked_poly_mask == 0)
    text_region_boundary_training_mask = 1 - boundary.astype(np.uint8)
    return score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask
def all(iterable):
    '''
    return True if every element of iterable is truthy (True when empty)

    NOTE: this shadows the built-in all() for the whole module. Like the
    built-in, it stops consuming the iterable at the first falsy element.
    '''
    return not any(not element for element in iterable)
def get_text_file(image_file):
    '''
    map an image path to the path of its annotation file

    The final extension of the file name is replaced by ".txt" in the same
    directory, e.g. "dir/img_1.jpg" -> "dir/img_1.txt". Directory names that
    happen to contain the extension (e.g. "jpg_images/") and earlier dots in
    the file name are left untouched. No "gt_" prefix is added.
    :param image_file: path of the image
    :return: path of the annotation file
    '''
    return os.path.splitext(image_file)[0] + '.txt'
def pad_image(img, input_size, is_train):
    '''
    paste img onto a black uint8 square canvas

    The canvas side is the largest of the image height, image width and
    input_size. The image is placed at a random offset when is_train is
    true (row offset drawn first), and centred otherwise.
    :return: (canvas, row offset, column offset) of the pasted image
    '''
    rows, cols, _ = img.shape
    side = np.max([rows, cols, input_size])
    canvas = np.zeros((side, side, 3), dtype=np.uint8)
    if not is_train:
        top = (side - rows) // 2
        left = (side - cols) // 2
    else:
        top = np.random.randint(side - rows + 1)
        left = np.random.randint(side - cols + 1)
    canvas[top:top + rows, left:left + cols, :] = img
    return canvas, top, left
def resize_image(img, text_polys, input_size, shift_h, shift_w):
    '''
    resize a (padded) image to input_size x input_size and map its polygons

    In this file it is called with the canvas returned by pad_image, with
    shift_h/shift_w being the offsets pad_image returned. The polygons are
    first moved by those offsets and then scaled by the resize ratio.
    NOTE: text_polys is modified in place and also returned.
    :return: (resized image, text_polys)
    '''
    src_h, src_w, _ = img.shape
    img = cv2.resize(img, dsize=(input_size, input_size))
    scale_x = input_size / float(src_w)
    scale_y = input_size / float(src_h)
    xs = text_polys[:, :, 0]  # views into text_polys
    ys = text_polys[:, :, 1]
    xs += shift_w
    ys += shift_h
    xs *= scale_x
    ys *= scale_y
    return img, text_polys
class threadsafe_iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):  # Python 3
        with self.lock:
            return next(self.it)

    # Python 2 protocol name. The built-in next() works for both Python 2 and
    # Python 3 iterators, whereas the previous `self.it.next()` raised
    # AttributeError on Python 3.
    next = __next__
def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe.

    Each call of the decorated function returns a threadsafe_iter wrapping
    the generator f creates. functools.wraps keeps f's name and docstring on
    the wrapper.
    """
    import functools

    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g
def _show_sample(im, text_polys, score_map, geo_map, training_mask):
    '''
    plot one preprocessed sample for inspection (used when vis=True)

    Panels: the image with its polygons (labelled "height-width"), the score
    map, geo map channels 0-2, and the training mask.
    '''
    fig, axs = plt.subplots(3, 2, figsize=(20, 30))
    axs[0, 0].imshow(im[:, :, ::-1])
    for poly in text_polys:
        poly_h = min(abs(poly[3, 1] - poly[0, 1]), abs(poly[2, 1] - poly[1, 1]))
        poly_w = min(abs(poly[1, 0] - poly[0, 0]), abs(poly[2, 0] - poly[3, 0]))
        axs[0, 0].add_artist(Patches.Polygon(
            poly, facecolor='none', edgecolor='green', linewidth=2, linestyle='-', fill=True))
        axs[0, 0].text(poly[0, 0], poly[0, 1], '{:.0f}-{:.0f}'.format(poly_h, poly_w), color='purple')
    axs[0, 1].imshow(score_map[::, ::])
    axs[1, 0].imshow(geo_map[::, ::, 0])
    axs[1, 1].imshow(geo_map[::, ::, 1])
    axs[2, 0].imshow(geo_map[::, ::, 2])
    axs[2, 1].imshow(training_mask[::, ::])
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.show()
    plt.close(fig)


@threadsafe_generator
def generator(FLAGS, input_size=512, background_ratio=3./8, is_train=True, idx=None, random_scale=np.array([0.5, 1, 2.0, 3.0]), vis=False):
    '''
    endlessly yield training batches built from FLAGS.training_data_path

    Each image is randomly rescaled. With probability background_ratio a
    text-free crop is used as a background sample with all-zero targets;
    otherwise a crop containing text is used. Samples that fail to load or
    process are skipped (tracebacks are printed unless
    FLAGS.suppress_warnings_and_error_messages).
    NOTE(review): background samples are resized to `input_size` and get 8
    geo channels unless FLAGS.geometry == 'RBOX', while text samples use
    FLAGS.input_size and always 5 channels -- confirm these agree.
    :param idx: optional index array selecting a subset of the images
    :param vis: plot every sample before it is added to the batch
    :yield: ([images, overly_small_text_region_training_masks,
              text_region_boundary_training_masks, score_maps],
             [score_maps, geo_maps]) with FLAGS.batch_size samples each.
             Images are scaled to [-1, 1] with the channel order reversed;
             maps are subsampled by 4.
    '''
    image_list = np.array(get_images(FLAGS.training_data_path))
    if idx is not None:
        image_list = image_list[idx]
    print('{} training images in {}'.format(
        image_list.shape[0], FLAGS.training_data_path))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        images = []
        score_maps = []
        geo_maps = []
        overly_small_text_region_training_masks = []
        text_region_boundary_training_masks = []
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                h, w, _ = im.shape
                txt_fn = get_text_file(im_fn)
                if not os.path.exists(txt_fn):
                    if not FLAGS.suppress_warnings_and_error_messages:
                        print('text file {} does not exists'.format(txt_fn))
                    continue
                text_polys, text_tags = load_annotation(txt_fn)
                text_polys, text_tags = check_and_validate_polys(FLAGS, text_polys, text_tags, (h, w))
                # random scale this image, with a small independent jitter per axis
                rd_scale = np.random.choice(random_scale)
                x_scale_variation = np.random.randint(-10, 10) / 100.
                y_scale_variation = np.random.randint(-10, 10) / 100.
                im = cv2.resize(im, dsize=None, fx=rd_scale + x_scale_variation, fy=rd_scale + y_scale_variation)
                text_polys[:, :, 0] *= rd_scale + x_scale_variation
                text_polys[:, :, 1] *= rd_scale + y_scale_variation
                # random crop a area from image
                if np.random.rand() < background_ratio:
                    # crop background
                    im, text_polys, text_tags = crop_area(FLAGS, im, text_polys, text_tags, crop_background=True)
                    if text_polys.shape[0] > 0:
                        # cannot find background
                        continue
                    # pad and resize image; background samples get empty targets
                    im, _, _ = pad_image(im, FLAGS.input_size, is_train)
                    im = cv2.resize(im, dsize=(input_size, input_size))
                    score_map = np.zeros((input_size, input_size), dtype=np.uint8)
                    geo_map_channels = 5 if FLAGS.geometry == 'RBOX' else 8
                    geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
                    overly_small_text_region_training_mask = np.ones((input_size, input_size), dtype=np.uint8)
                    text_region_boundary_training_mask = np.ones((input_size, input_size), dtype=np.uint8)
                else:
                    im, text_polys, text_tags = crop_area(FLAGS, im, text_polys, text_tags, crop_background=False)
                    if text_polys.shape[0] == 0:
                        continue
                    im, shift_h, shift_w = pad_image(im, FLAGS.input_size, is_train)
                    im, text_polys = resize_image(im, text_polys, FLAGS.input_size, shift_h, shift_w)
                    new_h, new_w, _ = im.shape
                    score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask = generate_rbox(FLAGS, (new_h, new_w), text_polys, text_tags)
                if vis:
                    # previously referenced an undefined `training_mask`, so the
                    # resulting NameError skipped every sample when vis=True
                    _show_sample(im, text_polys, score_map, geo_map, overly_small_text_region_training_mask)
                im = (im / 127.5) - 1.
                images.append(im[:, :, ::-1].astype(np.float32))
                score_maps.append(score_map[::4, ::4, np.newaxis].astype(np.float32))
                geo_maps.append(geo_map[::4, ::4, :].astype(np.float32))
                overly_small_text_region_training_masks.append(overly_small_text_region_training_mask[::4, ::4, np.newaxis].astype(np.float32))
                text_region_boundary_training_masks.append(text_region_boundary_training_mask[::4, ::4, np.newaxis].astype(np.float32))
                if len(images) == FLAGS.batch_size:
                    yield [np.array(images), np.array(overly_small_text_region_training_masks), np.array(text_region_boundary_training_masks), np.array(score_maps)], [np.array(score_maps), np.array(geo_maps)]
                    images = []
                    score_maps = []
                    geo_maps = []
                    overly_small_text_region_training_masks = []
                    text_region_boundary_training_masks = []
            except Exception:
                # best effort: a bad sample is reported and skipped
                import traceback
                if not FLAGS.suppress_warnings_and_error_messages:
                    traceback.print_exc()
                continue
@threadsafe_generator
def val_generator(FLAGS, idx=None, is_train=False):
    '''
    endlessly yield validation batches built from FLAGS.validation_data_path

    Images are padded/resized without random scaling or cropping. Samples
    that fail to load or process are skipped (tracebacks are printed unless
    FLAGS.suppress_warnings_and_error_messages).
    :param idx: optional index array selecting a subset of the images
    :param is_train: forwarded to pad_image (False centres the image)
    :yield: ([images, overly_small_text_region_training_masks,
              text_region_boundary_training_masks, score_maps],
             [score_maps, geo_maps]) with FLAGS.batch_size samples each;
             maps are subsampled by 4
    '''
    image_list = np.array(get_images(FLAGS.validation_data_path))
    if idx is not None:
        image_list = image_list[idx]
    # report the directory that was actually scanned (was training_data_path)
    print('{} validation images in {}'.format(
        image_list.shape[0], FLAGS.validation_data_path))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        images = []
        score_maps = []
        geo_maps = []
        overly_small_text_region_training_masks = []
        text_region_boundary_training_masks = []
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                h, w, _ = im.shape
                txt_fn = get_text_file(im_fn)
                if not os.path.exists(txt_fn):
                    if not FLAGS.suppress_warnings_and_error_messages:
                        print('text file {} does not exists'.format(txt_fn))
                    continue
                text_polys, text_tags = load_annotation(txt_fn)
                text_polys, text_tags = check_and_validate_polys(FLAGS, text_polys, text_tags, (h, w))
                im, shift_h, shift_w = pad_image(im, FLAGS.input_size, is_train)
                im, text_polys = resize_image(im, text_polys, FLAGS.input_size, shift_h, shift_w)
                new_h, new_w, _ = im.shape
                score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask = generate_rbox(FLAGS, (new_h, new_w), text_polys, text_tags)
                im = (im / 127.5) - 1.
                images.append(im[:, :, ::-1].astype(np.float32))
                score_maps.append(score_map[::4, ::4, np.newaxis].astype(np.float32))
                geo_maps.append(geo_map[::4, ::4, :].astype(np.float32))
                overly_small_text_region_training_masks.append(overly_small_text_region_training_mask[::4, ::4, np.newaxis].astype(np.float32))
                text_region_boundary_training_masks.append(text_region_boundary_training_mask[::4, ::4, np.newaxis].astype(np.float32))
                if len(images) == FLAGS.batch_size:
                    yield [np.array(images), np.array(overly_small_text_region_training_masks), np.array(text_region_boundary_training_masks), np.array(score_maps)], [np.array(score_maps), np.array(geo_maps)]
                    images = []
                    score_maps = []
                    geo_maps = []
                    overly_small_text_region_training_masks = []
                    text_region_boundary_training_masks = []
            except Exception:
                # best effort: a bad sample is reported and skipped
                import traceback
                if not FLAGS.suppress_warnings_and_error_messages:
                    traceback.print_exc()
                continue
def count_samples(FLAGS):
    '''
    count the ".jpg" files directly inside FLAGS.training_data_path

    Only the lowercase ".jpg" suffix is counted and subdirectories are not
    searched, so the result can differ from len(get_images(...)) when the
    folder also holds png/jpeg/JPG images.
    '''
    # the built-in next() takes the top-level entry of the os.walk generator
    _, _, filenames = next(os.walk(FLAGS.training_data_path))
    return sum(1 for name in filenames if name[-4:] == ".jpg")
def load_data_process(args):
    '''
    load and preprocess one sample (worker function for load_data)

    :param args: tuple (image_file, FLAGS, is_train), packed into one
        argument because load_data maps this function over a single iterable
    :return: (image, image_file, score_map, geo_map,
              overly_small_text_region_training_mask,
              text_region_boundary_training_mask), with all maps subsampled
              by 4 -- or None when the annotation file is missing or any
              processing step fails
    '''
    (image_file, FLAGS, is_train) = args
    try:
        img = cv2.imread(image_file)
        h, w, _ = img.shape
        txt_file = get_text_file(image_file)
        if not os.path.exists(txt_file):
            # skip explicitly instead of relying on a later crash
            if not FLAGS.suppress_warnings_and_error_messages:
                print('text file {} does not exists'.format(txt_file))
            return None
        text_polys, text_tags = load_annotation(txt_file)
        text_polys, text_tags = check_and_validate_polys(FLAGS, text_polys, text_tags, (h, w))
        img, shift_h, shift_w = pad_image(img, FLAGS.input_size, is_train=is_train)
        img, text_polys = resize_image(img, text_polys, FLAGS.input_size, shift_h, shift_w)
        new_h, new_w, _ = img.shape
        score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask = generate_rbox(FLAGS, (new_h, new_w), text_polys, text_tags)
        img = (img / 127.5) - 1.
        return (img[:, :, ::-1].astype(np.float32),
                image_file,
                score_map[::4, ::4, np.newaxis].astype(np.float32),
                geo_map[::4, ::4, :].astype(np.float32),
                overly_small_text_region_training_mask[::4, ::4, np.newaxis].astype(np.float32),
                text_region_boundary_training_mask[::4, ::4, np.newaxis].astype(np.float32))
    except Exception:
        # best effort: report (unless suppressed) and let load_data drop it
        import traceback
        if not FLAGS.suppress_warnings_and_error_messages:
            traceback.print_exc()
        return None
def load_data(FLAGS, is_train=False):
    '''
    load every image in FLAGS.validation_data_path with a process pool

    Each image is processed by load_data_process in one of FLAGS.nb_workers
    worker processes; images for which it returns None are skipped.
    :param is_train: forwarded to load_data_process / pad_image
    :return: (images, overly_small_text_region_training_masks,
              text_region_boundary_training_masks, score_maps, geo_maps)
             as numpy arrays
    '''
    image_files = np.array(get_images(FLAGS.validation_data_path))
    pool = Pool(FLAGS.nb_workers)
    try:
        if sys.version_info >= (3, 0):
            jobs = zip(image_files, itertools.repeat(FLAGS), itertools.repeat(is_train))
        else:
            jobs = itertools.izip(image_files, itertools.repeat(FLAGS), itertools.repeat(is_train))
        # NOTE(review): get() with a huge timeout is presumably the usual
        # workaround that lets KeyboardInterrupt reach the parent -- confirm
        loaded_data = pool.map_async(load_data_process, jobs).get(9999999)
    except BaseException:
        # do not leave worker processes running on failure or Ctrl-C
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    samples = [item for item in loaded_data if item is not None]
    images = [s[0] for s in samples]
    score_maps = [s[2] for s in samples]
    geo_maps = [s[3] for s in samples]
    overly_small_text_region_training_masks = [s[4] for s in samples]
    text_region_boundary_training_masks = [s[5] for s in samples]
    print('Number of validation images : %d' % len(images))
    return np.array(images), np.array(overly_small_text_region_training_masks), np.array(text_region_boundary_training_masks), np.array(score_maps), np.array(geo_maps)
if __name__ == '__main__':
    # Intentionally empty: running this file directly executes nothing;
    # its functions are meant to be imported.
    pass
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment