Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import training # 自前の学習器
import chainer.serializers
import chainer.functions as F
import argparse
import numpy as np
from PIL import Image
import os
from bottle import route, run
from bottle import get, post, put, request
from io import BytesIO
import chainer
from chainercv.datasets import voc_detection_label_names
from chainercv.links import SSD300
from chainercv.links import SSD512
from chainercv import utils
from chainercv.visualizations import vis_bbox
import chainercv.utils.bbox
import cv2
# Card-type classifier (project-local MLP from `training`); its trained
# weights are loaded into this instance by main().
model = training.MLP()
# Pretrained SSD300 detector (PASCAL VOC 07+12 weights). Despite the name it is
# not an R-CNN; it proposes object regions that are then searched for cards.
rcnn_model = SSD300(
    n_fg_class=len(voc_detection_label_names),
    pretrained_model='voc0712')
# Select GPU 0 and move the detector there. `cuda.get_device` is deprecated
# since Chainer v2; `get_device_from_id` is the supported form for an int id.
chainer.cuda.get_device_from_id(0).use()
rcnn_model.to_gpu()
# 'evaluate' preset: per the ChainerCV docs this uses a low score threshold,
# so many low-confidence boxes are returned.
rcnn_model.use_preset('evaluate')
def detect_cardtype(image):
    """Classify a card image with the global MLP ``model``.

    Args:
        image: anything ``np.asarray`` turns into an (H, W, C) array, e.g. a
            PIL image. Presumably it must match the size/channel layout the
            MLP was trained on -- TODO confirm against the training script.

    Returns:
        int: index of the highest-scoring class.
    """
    # Convert HWC -> CHW float32 and add a batch axis of size 1 (NCHW).
    pixels = np.asarray(image).astype(np.float32)
    pixels = pixels.transpose(2, 0, 1)
    pixels = pixels.reshape((1,) + pixels.shape)
    print("NN input image shape = {}".format(pixels.shape) )
    # Inference only: Chainer defaults to train=True, which would keep any
    # dropout / batch-norm layers in training behaviour. Also skip building
    # the backprop graph, which is never used here.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        y = model(pixels)
        prediction = F.softmax(y)
    # Return a plain Python int rather than a NumPy integer scalar.
    return int(np.argmax(prediction.data))
# Perspective rectification of a detected card.
def card_perspective(img, box):
    """Warp the quadrilateral ``box`` of ``img`` onto an upright 223x311 card and crop it.

    Args:
        img: source image array passed to ``cv2.warpPerspective``; the final
            3-index slice assumes it has a channel axis.
        box: four corner points (converted to float32); they are mapped in
            order onto top-right, bottom-right, bottom-left, top-left.

    Returns:
        The warped image cropped to ``[36:227, 42:227, :]``.
    """
    card_w, card_h = 223., 311.
    # Destination corners: start at top-right and go clockwise.
    target_corners = np.array(
        [
            [card_w, 0.],
            [card_w, card_h],
            [0., card_h],
            [0., 0.],
        ],
        dtype=np.float32,
    )
    print("box shape = {}".format(box.shape) )
    print("basepts shape = {}".format(target_corners.shape) )

    homography = cv2.getPerspectiveTransform(box.astype(np.float32), target_corners)
    out_size = (int(card_w), int(card_h))
    rectified = cv2.warpPerspective(img, homography, out_size)
    print("warped_image shape = {}".format(rectified.shape) )

    # Crop margins, ported from an earlier C++ implementation:
    #   imput = 223 x 311
    #   cv::Point leftup_mergin(16+26, 36);
    #   cv::Point rightbottom_mergin(16+26, 48);
    #   img_transform(dst_img2, dst_img3,
    #       cv::Rect(leftup_mergin.x, leftup_mergin.y,
    #                dst_img3.cols - leftup_mergin.x - rightbottom_mergin.x,
    #                dst_img3.cols - leftup_mergin.y - rightbottom_mergin.y),
    #       crop_transform);
    # NOTE(review): the slice below does not reproduce that rectangle (which
    # would be 139x139 at x=42, y=36). Its column end 227 exceeds the 223-px
    # width and is silently clipped, so the result is 191 rows x 181 columns.
    # Kept unchanged because the classifier was presumably trained on this
    # crop -- confirm before altering.
    cropped = rectified[36:227, 42:227, :]
    print("cropped_image shape = {}".format(cropped.shape) )
    return cropped
# Card detector.
def card_detect(img):
    """Locate a card in ``img`` and return its rotated bounding box.

    Pipeline: Gaussian blur -> inverse binary threshold -> Canny edges ->
    probabilistic Hough line segments -> minimum-area rotated rectangle that
    encloses every segment endpoint.

    Args:
        img: 8-bit image array (H x W or H x W x C).

    Returns:
        A (4, 2) integer array of corner points (from ``cv2.boxPoints``), or
        ``None`` when the Hough transform finds no segments or fewer than four.
    """
    # Smooth, binarise (inverted: pixels <= 85 become 255) and extract edges.
    blurred = cv2.GaussianBlur(img, (5, 5), 8)
    _, bin_img = cv2.threshold(blurred, 85, 255, cv2.THRESH_BINARY_INV)
    canny_img = cv2.Canny(bin_img, 50, 200)

    # Probabilistic Hough transform for line segments. minLineLength and
    # maxLineGap MUST be keywords: the 5th positional parameter of
    # cv2.HoughLinesP is the optional output array `lines`, so passing them
    # positionally shifts both values into the wrong parameters.
    # shape[:2] also works for single-channel images.
    h, w = img.shape[:2]
    lines = cv2.HoughLinesP(
        canny_img, 1, np.pi / 180, 40,
        minLineLength=min(w, h) // 4,
        maxLineGap=10,
    )
    if lines is None or len(lines) < 4:
        return None

    # Each segment is (x1, y1, x2, y2); flatten into an N x 2 list of endpoints.
    pts = np.asarray(lines).reshape(-1, 2)
    # Rotation-aware (minimum-area) rectangle enclosing all endpoints.
    rect = cv2.minAreaRect(pts)
    box = cv2.boxPoints(rect)
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0; astype(np.intp)
    # gives the same truncated integer corners.
    return box.astype(np.intp)
# Bottle's `route` decorator binds a URL path to the function below.
@route('/hello')
def hello():
    """Return a fixed plain-text greeting for the /hello path."""
    greeting = "Hello World!"
    return greeting
@put('/resource')
def put_resource():
    """Classify an image sent as the raw request body of ``PUT /resource``.

    The body is also written to ``test.png`` as a debugging aid, decoded in
    memory with PIL and classified by :func:`detect_cardtype`.

    Returns:
        str: ``"ok,<class index>"``.
    """
    # Raw request body (encoded image bytes).
    data = request.body.read()
    print("received data len = {}".format(len(data) ) )
    # Keep a copy of the upload; `with` closes the file even if write fails.
    with open("test.png", "wb") as f:
        f.write(data)
    # Decode the image bytes in memory.
    image = Image.open(BytesIO(data))
    # Preprocessing + classification previously duplicated inline here are
    # exactly what detect_cardtype does, so delegate to it.
    m = detect_cardtype(image)
    print ("detect card num = {}".format(m) )
    return ("ok,%d" % m)
@put('/resource2')
def put_resource2():
    """Detect objects in an uploaded image, then search each one for a card.

    Steps: run the SSD detector (``rcnn_model``) on the decoded body. For
    every proposed box, crop it, find a card outline with
    :func:`card_detect`, rectify it with :func:`card_perspective`, save it to
    ``./tmp/crop<i>.jpg`` and classify it with :func:`detect_cardtype`.

    Returns:
        str: ``"ok <n> objects"`` where n is the number of detector outputs.
    """
    data = request.body.read()
    print("received data len = {}".format(len(data) ) )
    # Save the body to confirm the upload arrived intact (debug aid).
    with open("test.png", "wb") as f:
        f.write(data)

    # Decode with PIL for the detector and convert HWC -> CHW float32.
    raw_image = Image.open(BytesIO(data))
    raw_pixels = np.asarray(raw_image).astype(np.float32)
    pixels = raw_pixels.transpose(2, 0, 1)
    # Object proposals; each box is unpacked below as (top, left, bottom, right).
    bboxes, labels, scores = rcnn_model.predict([pixels])
    bbox, label, score = bboxes[0], labels[0], scores[0]

    # Decode the same bytes with OpenCV for the card-detection pipeline.
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    narray = np.frombuffer(data, dtype='uint8')
    base_img = cv2.imdecode(narray, -1)
    cv2.imwrite("cvtest.png", base_img)
    print("imgshape = {}".format(base_img.shape) )
    print("base image shape = {}".format(base_img.shape) )

    # cv2.imwrite silently returns False when the target directory is missing.
    os.makedirs("./tmp", exist_ok=True)
    for i, b in enumerate(bbox):
        filename = "./tmp/crop" + str(i) + ".jpg"
        print("boundingbox{} = {}".format( filename, b ) )
        top, left, bottom, right = b
        # Clamp negative starts; slicing already clips ends past the image.
        t = 0 if top < 0 else top
        l = 0 if left < 0 else left
        croped_img = base_img[int(t):int(bottom), int(l):int(right)]
        if croped_img.size == 0:
            # Degenerate or off-image box: nothing to search, and the blur in
            # card_detect would fail on an empty array.
            continue
        croped_img_clone = np.array(croped_img)
        # Named card_box so the detector's `bbox` array is not rebound mid-loop.
        card_box = card_detect(croped_img_clone)
        if card_box is not None:
            card_image = card_perspective(croped_img_clone, card_box)
            print("card img shape = {}".format(card_image.shape) )
            cv2.imwrite(filename, card_image)
            # Reverse the channel axis (OpenCV BGR -> RGB) for PIL. This also
            # flips rows vertically. NOTE(review): confirm the vertical flip
            # is intended (e.g. it matches the classifier's training data).
            pil_raw = card_image[::-1, :, ::-1].copy()
            pil_img = Image.fromarray(pil_raw)
            card_image_label = detect_cardtype(pil_img)
            print("{}".format( card_image_label ) )
    return ("ok %d objects" % len(label) )
def main():
    """Load the trained classifier weights and start Bottle's development server.

    Command-line options (defaults reproduce the previously hard-coded values):
      --model  path to the MLP weights in NPZ format (result/mymodel.npz)
      --host   interface to bind (0.0.0.0)
      --port   TCP port (8080)
    """
    parser = argparse.ArgumentParser(description='Card recognition HTTP server')
    parser.add_argument('--model', default='result/mymodel.npz',
                        help='trained MLP weights (.npz)')
    parser.add_argument('--host', default='0.0.0.0', help='interface to bind')
    parser.add_argument('--port', type=int, default=8080, help='TCP port')
    args = parser.parse_args()
    # load_npz deserializes into the existing `model` object in place, so no
    # `global` rebinding is needed.
    chainer.serializers.load_npz(args.model, model)
    # Built-in development server with debug mode and auto-reloader enabled.
    # NOTE(review): in debug mode Bottle shows tracebacks in error pages. When
    # bound to 0.0.0.0 those pages are visible to any client that can reach
    # this host.
    run(host=args.host, port=args.port, debug=True, reloader=True)
if __name__ == '__main__':
    # Start the server only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment