@tubackkhoa
Last active May 11, 2020 01:28
ctpn
# coding=utf-8
import os
import sys
import time

import cv2
import numpy as np
import tensorflow as tf

sys.path.append(os.getcwd())

from nets import model_train as model
from utils.rpn_msr.proposal_layer import proposal_layer
from utils.text_connector.detectors import TextDetector

tf.app.flags.DEFINE_string('input', 'data/demo/', '')
tf.app.flags.DEFINE_string('gpu', '0', '')
tf.app.flags.DEFINE_string('mode', 'H', '')
tf.app.flags.DEFINE_bool('remove_collapsed', True, '')
tf.app.flags.DEFINE_bool('print_score', False, '')
tf.app.flags.DEFINE_float('min_good_ratio', 2.0, '')
tf.app.flags.DEFINE_integer('max_good_angle', 15, '')
tf.app.flags.DEFINE_string('checkpoint_path', 'checkpoints_mlt/', '')
FLAGS = tf.app.flags.FLAGS

input_image = tf.placeholder(
    tf.float32, shape=[None, None, None, 3], name='input_image')
input_im_info = tf.placeholder(
    tf.float32, shape=[None, 3], name='input_im_info')
bbox_pred, cls_pred, cls_prob = model.model(input_image)

exts = ['.jpg', '.png', '.jpeg', '.JPG']


def get_images(input):
    # Collect image paths: a single file if `input` is an image, otherwise
    # walk the directory tree and keep files with a known extension.
    files = []
    _, ext = os.path.splitext(input)
    if ext in exts:
        files.append(input)
    else:
        for parent, dirnames, filenames in os.walk(input):
            for filename in filenames:
                for ext in exts:
                    if filename.endswith(ext):
                        files.append(os.path.join(parent, filename))
                        break
    print("Found {} images".format(len(files)))
    return files


def resize_image(img, max_size=1200, round=True):
    # Scale the longer side to `max_size`, optionally rounding each side up to a
    # multiple of 16 so the image matches the network's 16-pixel stride.
    h, w, _ = img.shape
    im_size_max = max(h, w)
    im_scale = float(max_size) / float(im_size_max)
    new_h = int(h * im_scale)
    new_w = int(w * im_scale)
    if round:
        new_h = new_h if new_h // 16 == 0 else (new_h // 16 + 1) * 16
        new_w = new_w if new_w // 16 == 0 else (new_w // 16 + 1) * 16
    re_im = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return re_im, (new_h / h, new_w / w)


def get_sess_and_detector():
    with tf.get_default_graph().as_default():
        global_step = tf.get_variable(
            'global_step', [], initializer=tf.constant_initializer(0), trainable=False)
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
        model_path = os.path.join(
            FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        textdetector = TextDetector(DETECT_MODE=FLAGS.mode)
        return (sess, textdetector)


def main(argv=None):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    sess, textdetector = get_sess_and_detector()
    img_list = get_images(FLAGS.input)
    for im_fn in img_list:
        print('===============')
        print(im_fn)
        start = time.time()
        try:
            im = cv2.imread(im_fn)[:, :, ::-1]  # BGR -> RGB
        except Exception:
            print("Error reading image {}!".format(im_fn))
            continue

        img, (rh, rw) = resize_image(im)
        input_img = cv2.GaussianBlur(img, (7, 7), 0)
        # input_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
        h, w, c = img.shape
        im_info = np.array([h, w, c]).reshape([1, 3])
        bbox_pred_val, cls_prob_val = sess.run(
            [bbox_pred, cls_prob],
            feed_dict={input_image: [input_img], input_im_info: im_info})

        textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)
        scores = textsegs[:, 0]
        textsegs = textsegs[:, 1:5]

        boxes = textdetector.detect(
            textsegs, scores[:, np.newaxis], img.shape[:2])
        boxes = np.array(boxes, dtype=np.int)

        cost_time = (time.time() - start)
        print("cost time: {:.2f}s".format(cost_time))

        for i, box in enumerate(boxes):
            cnt = box[:8].astype(np.int32).reshape((-1, 1, 2))
            rect = cv2.minAreaRect(cnt)
            box = cv2.boxPoints(rect)
            # Flag boxes whose rotated rectangle falls outside the image.
            collapsed = False
            for px, py in box:
                if px < 0 or py < 0 or px > img.shape[1] or py > img.shape[0]:
                    collapsed = True
                    break
            # Normalise the rectangle so that w >= h and the angle is absolute.
            (x, y), (w, h), a = rect
            if h > w:
                w, h = h, w
                a = 90 + a
            a = abs(a)
            ratio = w / h
            # Skip collapsed, insufficiently elongated, or too-tilted boxes.
            if (FLAGS.remove_collapsed and collapsed) or ratio < FLAGS.min_good_ratio or a > FLAGS.max_good_angle:
                continue
            cv2.polylines(img, [cnt], True, color=(0, 255, 0),
                          thickness=2, lineType=cv2.LINE_AA)
            if FLAGS.print_score:
                cv2.putText(img, str(scores[i]), (int(x), int(y)),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

        # Scale back to the original size (rw and rh are the width and height
        # ratios returned by resize_image), then rescale to 600 px tall for display.
        img = cv2.resize(img, None, None, fx=1.0 / rw,
                         fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)
        dx = 600 / img.shape[0]
        img = cv2.resize(img, None, None, fx=dx,
                         fy=dx, interpolation=cv2.INTER_LINEAR)
        cv2.imshow(os.path.basename(im_fn), img[:, :, ::-1])
        cv2.waitKey(0)


if __name__ == '__main__':
    tf.app.run()
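
To call the detector from another script rather than through tf.app.run(), here is a minimal, untested usage sketch. It assumes the file above is saved as ctpn_demo.py in the root of a text-detection-ctpn style checkout, that checkpoints_mlt/ holds a trained CTPN checkpoint, and that data/demo/sample.jpg exists; the module name and image path are placeholders, not part of the gist.

# usage_sketch.py -- illustrative only; module name and paths are assumptions.
import cv2
import numpy as np

import ctpn_demo  # hypothetical module name for the script above

sess, detector = ctpn_demo.get_sess_and_detector()

im = cv2.imread('data/demo/sample.jpg')[:, :, ::-1]  # BGR -> RGB; placeholder path
img, (rh, rw) = ctpn_demo.resize_image(im)
h, w, c = img.shape
im_info = np.array([h, w, c]).reshape([1, 3])

# Feed the resized RGB image (the gist additionally applies a Gaussian blur).
bbox_val, prob_val = sess.run(
    [ctpn_demo.bbox_pred, ctpn_demo.cls_prob],
    feed_dict={ctpn_demo.input_image: [img],
               ctpn_demo.input_im_info: im_info})
textsegs, _ = ctpn_demo.proposal_layer(prob_val, bbox_val, im_info)
boxes = detector.detect(textsegs[:, 1:5], textsegs[:, 0:1], img.shape[:2])
print('{} text boxes detected'.format(len(boxes)))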