Last active
May 11, 2020 01:28
-
-
Save tubackkhoa/4d74813b4aea7b15e9de13fe07c0a1e1 to your computer and use it in GitHub Desktop.
ctpn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8
import os
import sys
import time

import cv2
import numpy as np
import tensorflow as tf

# The working directory must be on sys.path *before* the project packages
# (`utils`, `nets`) are imported; appending it afterwards has no effect on
# those imports.
sys.path.append(os.getcwd())

from utils.text_connector.detectors import TextDetector  # noqa: E402
from utils.rpn_msr.proposal_layer import proposal_layer  # noqa: E402
from nets import model_train as model  # noqa: E402
# --- Command-line flags ------------------------------------------------------
# Image file, or directory that is walked recursively for images.
tf.app.flags.DEFINE_string("input", "data/demo/", "")
# Value exported as CUDA_VISIBLE_DEVICES before the session is created.
tf.app.flags.DEFINE_string("gpu", "0", "")
# Forwarded to TextDetector as DETECT_MODE.
tf.app.flags.DEFINE_string("mode", "H", "")
# Drop boxes whose min-area rectangle has a corner outside the image.
tf.app.flags.DEFINE_bool("remove_collapsed", True, "")
# Draw a score next to each kept box (flag name is misspelled; it is kept
# as-is because it is part of the command-line interface).
tf.app.flags.DEFINE_bool("print_scrore", False, "")
# Minimum long-side / short-side ratio for a box to be kept.
tf.app.flags.DEFINE_float("min_good_ratio", 2.0, "")
# Maximum absolute rotation angle for a box to be kept.
tf.app.flags.DEFINE_integer("max_good_angle", 15, "")
# Directory searched for the checkpoint to restore.
tf.app.flags.DEFINE_string("checkpoint_path", "checkpoints_mlt/", "")

FLAGS = tf.app.flags.FLAGS

# --- Inference graph (built once, at import time) ----------------------------
input_image = tf.placeholder(
    tf.float32, shape=[None, None, None, 3], name="input_image"
)
input_im_info = tf.placeholder(
    tf.float32, shape=[None, 3], name="input_im_info"
)

bbox_pred, cls_pred, cls_prob = model.model(input_image)

# File-name suffixes that get_images() treats as images.
exts = [".jpg", ".png", ".jpeg", ".JPG"]
def get_images(input):
    """Collect the image files designated by *input*.

    Extension matching is case-insensitive, so ``.PNG`` or ``.Jpeg`` files
    are picked up as well as the spellings listed in ``exts``.

    Args:
        input: Path of a single image file, or of a directory that is walked
            recursively.  (The name shadows the builtin; it is kept so that
            existing keyword callers continue to work.)

    Returns:
        list of str: ``[input]`` when *input* itself carries a known image
        extension, otherwise the paths of all files below *input* whose name
        ends with one.
    """
    known_exts = tuple(e.lower() for e in exts)

    _, ext = os.path.splitext(input)
    if ext.lower() in known_exts:
        files = [input]
    else:
        files = [
            os.path.join(parent, filename)
            for parent, _, filenames in os.walk(input)
            for filename in filenames
            # str.endswith accepts a tuple of suffixes.
            if filename.lower().endswith(known_exts)
        ]

    print("Find {} images".format(len(files)))
    return files
def resize_image(img, max_size=1200, round=True):
    """Resize *img* so that its longer side is about *max_size* pixels.

    Args:
        img: Image array whose shape unpacks as (height, width, channels).
        max_size: Target length of the longer side before rounding.
        round: When true, each side is rounded *up* to the next multiple of
            16; sides that already are a multiple of 16 are left unchanged.
            (The name shadows the builtin; kept for interface compatibility.)

    Returns:
        tuple: ``(resized, (scale_h, scale_w))`` where the scales are the
        ratios of the new height/width to the original height/width.
    """
    h, w, _ = img.shape
    im_scale = float(max_size) / float(max(h, w))

    # Never let a side collapse to 0 pixels (extreme aspect ratios), which
    # cv2.resize cannot handle.
    new_h = max(1, int(h * im_scale))
    new_w = max(1, int(w * im_scale))

    if round:
        # Round up to a multiple of 16.  The previous test (`// 16 == 0`)
        # checked the quotient instead of the remainder, so exact multiples
        # were needlessly enlarged by 16 and sides below 16 were not rounded.
        new_h = (new_h + 15) // 16 * 16
        new_w = (new_w + 15) // 16 * 16

    re_im = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return re_im, (new_h / float(h), new_w / float(w))
def get_sess_and_detector():
    """Create a session with the checkpoint restored, plus a text detector.

    The checkpoint is looked up in ``FLAGS.checkpoint_path``; variables are
    restored through the mapping produced by an ExponentialMovingAverage
    (decay 0.997), i.e. from their moving-average values where available.

    Returns:
        tuple: ``(sess, textdetector)`` - the ``tf.Session`` holding the
        restored weights and a ``TextDetector`` built with
        ``DETECT_MODE=FLAGS.mode``.

    Raises:
        FileNotFoundError: If no checkpoint state is found under
            ``FLAGS.checkpoint_path``.
    """
    with tf.get_default_graph().as_default():
        global_step = tf.get_variable(
            'global_step', [], initializer=tf.constant_initializer(0),
            trainable=False)
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a clear message instead of an AttributeError.
        ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
        if ckpt_state is None or not ckpt_state.model_checkpoint_path:
            raise FileNotFoundError(
                'No checkpoint found in {}'.format(FLAGS.checkpoint_path))

        model_path = os.path.join(
            FLAGS.checkpoint_path,
            os.path.basename(ckpt_state.model_checkpoint_path))

        # Create the session only once the checkpoint is known to exist, so
        # a failed lookup does not leave an open session behind.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)

        textdetector = TextDetector(DETECT_MODE=FLAGS.mode)
    return (sess, textdetector)
def main(argv=None):
    """Detect text in every input image and display the boxes that pass the filters.

    For each image found via ``FLAGS.input`` the image is resized, blurred,
    run through the network, and the resulting boxes are filtered by
    position (``FLAGS.remove_collapsed``), aspect ratio
    (``FLAGS.min_good_ratio``) and angle (``FLAGS.max_good_angle``).  Kept
    boxes are drawn on the image, which is then shown in an OpenCV window
    until a key is pressed.

    Args:
        argv: Unused; accepted because ``tf.app.run`` passes it.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    sess, textdetector = get_sess_and_detector()

    img_list = get_images(FLAGS.input)
    for im_fn in img_list:
        print('===============')
        print(im_fn)
        start = time.time()

        # cv2.imread reports an unreadable file by returning None, so test
        # for that explicitly instead of relying on a bare `except`.
        try:
            bgr = cv2.imread(im_fn)
        except cv2.error:
            bgr = None
        if bgr is None:
            print("Error reading image {}!".format(im_fn))
            continue
        im = bgr[:, :, ::-1]  # reverse the channel order (BGR -> RGB)

        img, (rh, rw) = resize_image(im)
        # The network sees a blurred copy; boxes are drawn on the sharp one.
        input_img = cv2.GaussianBlur(img, (7, 7), 0)

        h, w, c = img.shape
        im_info = np.array([h, w, c]).reshape([1, 3])
        bbox_pred_val, cls_prob_val = sess.run(
            [bbox_pred, cls_prob],
            feed_dict={input_image: [input_img], input_im_info: im_info})

        textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)
        scores = textsegs[:, 0]
        textsegs = textsegs[:, 1:5]

        boxes = textdetector.detect(
            textsegs, scores[:, np.newaxis], img.shape[:2])
        # `np.int` was only an alias of the builtin and has been removed
        # from recent NumPy releases.
        boxes = np.array(boxes, dtype=int)

        cost_time = (time.time() - start)
        print("cost time: {:.2f}s".format(cost_time))

        img_h, img_w = img.shape[:2]
        for i, box in enumerate(boxes):
            cnt = box[:8].astype(np.int32).reshape((-1, 1, 2))
            rect = cv2.minAreaRect(cnt)
            corners = cv2.boxPoints(rect)

            # "Collapsed": at least one rectangle corner lies outside the image.
            collapsed = any(
                px < 0 or py < 0 or px > img_w or py > img_h
                for px, py in corners)

            (cx, cy), (rect_w, rect_h), angle = rect
            # Normalise so that rect_w is the longer side.
            if rect_h > rect_w:
                rect_w, rect_h = rect_h, rect_w
                angle = 90 + angle
            angle = abs(angle)

            # A degenerate rectangle has no usable aspect ratio; skip it
            # rather than dividing by zero.
            if rect_h <= 0:
                continue
            ratio = rect_w / rect_h

            if ((FLAGS.remove_collapsed and collapsed)
                    or ratio < FLAGS.min_good_ratio
                    or angle > FLAGS.max_good_angle):
                continue

            cv2.polylines(img, [cnt], True, color=(0, 255, 0),
                          thickness=2, lineType=cv2.LINE_AA)
            if FLAGS.print_scrore:
                # NOTE(review): `scores` has one entry per segment returned
                # by proposal_layer, while `i` indexes the boxes returned by
                # textdetector.detect - confirm that these actually line up.
                cv2.putText(img, str(scores[i]), (int(cx), int(cy)),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2,
                            cv2.LINE_AA)

        # Undo the resize: fx scales the width and fy the height, so each
        # must use the matching factor (rw for width, rh for height).
        img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh,
                         interpolation=cv2.INTER_LINEAR)
        # Scale uniformly so the displayed image is 600 pixels high.
        dx = 600.0 / img.shape[0]
        img = cv2.resize(img, None, None, fx=dx, fy=dx,
                         interpolation=cv2.INTER_LINEAR)

        cv2.imshow(os.path.basename(im_fn), img[:, :, ::-1])
        cv2.waitKey(0)
# Script entry point: hand control to tf.app.run, which invokes main().
if __name__ == "__main__":
    tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment