Skip to content

Instantly share code, notes, and snippets.

@ppq200
Created May 23, 2018 12:10
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ppq200/02c61b79e227d5395b9717160bb09f31 to your computer and use it in GitHub Desktop.
Save ppq200/02c61b79e227d5395b9717160bb09f31 to your computer and use it in GitHub Desktop.
YOLOを使って車載動画からバイクとのすれ違いを抜き出す
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# yolov3を使ってバイク車載動画からバイクすれ違い動画を作成する"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Imports related to darknet (YOLO) and image handling.\n",
"\n",
"import cv2 \n",
"# NOTE(review): `imread` is not referenced by any visible cell (imageio.imread is used instead);\n",
"# scipy.misc.imread is deprecated/removed in newer SciPy releases -- verify against the installed SciPy.\n",
"from scipy.misc import imread\n",
"\n",
"import sys, os\n",
"# Put <cwd>/python/ on the import path so `import darknet` below finds python/darknet.py.\n",
"sys.path.append(os.path.join(os.getcwd(),'python/'))\n",
"\n",
"# Install darknet and download YOLOv3: https://pjreddie.com/darknet/yolo/\n",
"import darknet as dn # make it Python 3 compatible with: 2to3.py -w python/darknet.py\n",
"# For \"OSError: libdarknet.so: cannot open shared object file: No such file or directory\",\n",
"# give libdarknet.so as a full path in python/darknet.py (around line 48)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"from IPython import display\n",
"\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import skvideo.io\n",
"import skimage.io"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly\n",
"# Enable plotly offline rendering in this notebook (plotly.offline.iplot is used by later cells).\n",
"# connected=False: presumably plotly.js is embedded rather than loaded from a CDN -- confirm in the plotly docs.\n",
"plotly.offline.init_notebook_mode(connected=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show matplotlib figures inline in the cell output.\n",
"%matplotlib inline\n",
"# Alternative backend with interactive figures (swap with the line above if wanted).\n",
"#%matplotlib notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def array_to_image(arr):\n",
"    \"\"\"Convert an H x W x C image array into a darknet IMAGE.\n",
"\n",
"    The pixels are reordered from HWC to planar CHW and divided by 255.0\n",
"    (assumes 0-255 pixel values), then stored as 32-bit floats.\n",
"    \"\"\"\n",
"    # HWC -> CHW; channel count, height and width come from the transposed shape.\n",
"    chw = arr.transpose(2, 0, 1)\n",
"    c, h, w = chw.shape\n",
"    # Scale in float64 as before, then store as a C-contiguous float32 buffer;\n",
"    # its memory order equals the order the old flatten() call produced.\n",
"    pixels = np.ascontiguousarray(chw / 255.0, dtype=np.float32)\n",
"    # Share the numpy buffer with ctypes instead of going through dn.c_array,\n",
"    # which presumably fills a ctypes array element by element (slow for HD frames).\n",
"    # from_buffer keeps `pixels` alive; IMAGE keeps a reference to `data`, as before.\n",
"    data = (dn.c_float * pixels.size).from_buffer(pixels)\n",
"    return dn.IMAGE(w, h, c, data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the darknet network definition and weights.\n",
"\n",
"# Paths must be passed as bytes (b\"...\"), see:\n",
"# https://github.com/pjreddie/darknet/issues/241 \n",
"net = dn.load_net(b\"cfg/yolov3.cfg\", b\"yolov3.weights\", 0)\n",
"meta = dn.load_meta(b\"cfg/coco.data\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test with the stock dn.detect, which takes an image file path (a JPEG here).\n",
"r = dn.detect(net, meta, b\"data/dog.jpg\")\n",
"print(r)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ./python/darknet.py に以下のような detect2 関数を追加する\n",
"# detect 関数で使っている load_image を使わず、読み込み済みの IMAGE を受け取るもの\n",
"#def detect2(net, meta, im, thresh=.5, hier_thresh=.5, nms=.45):\n",
"#    num = c_int(0)\n",
"#    pnum = pointer(num)\n",
"#    predict_image(net, im)\n",
"#    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)\n",
"#    num = pnum[0]\n",
"#    if (nms): do_nms_obj(dets, num, meta.classes, nms);\n",
"#\n",
"#    res = []\n",
"#    for j in range(num):\n",
"#        for i in range(meta.classes):\n",
"#            if dets[j].prob[i] > 0:\n",
"#                b = dets[j].bbox\n",
"#                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))\n",
"#    res = sorted(res, key=lambda x: -x[1])\n",
"#    free_detections(dets, num)\n",
"#    return res\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test via imageio + array_to_image (an in-memory array instead of a file path).\n",
"import imageio\n",
"arr= imageio.imread('data/dog.jpg')\n",
"\n",
"im = array_to_image(arr)\n",
"r = dn.detect2(net, meta, im) # detect2 must first be added to python/darknet.py by hand\n",
"print(r)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw detection boxes onto a frame.\n",
"def draw_bb(r, frame, target_label=b'motorbike', color=(255, 0, 0), show=True):\n",
"    \"\"\"Draw the boxes of one class onto `frame` in place and return the frame.\n",
"\n",
"    r: detections as produced by dn.detect / dn.detect2, i.e. a list of\n",
"       (label_bytes, probability, (center_x, center_y, width, height)).\n",
"    target_label: only detections with this label are drawn.\n",
"    color: rectangle/text colour; (255, 0, 0) shows as red for RGB frames.\n",
"    show: when True, display the annotated frame with matplotlib.\n",
"    \"\"\"\n",
"    print(r)\n",
"    for obj in r:\n",
"        label, prob, box = obj[0], obj[1], obj[2]\n",
"        # Skip every class except the requested one.\n",
"        if label != target_label:\n",
"            continue\n",
"\n",
"        # Boxes are centre-based; convert to top-left / bottom-right corners.\n",
"        x, y = int(box[0]), int(box[1])\n",
"        half_w, half_h = int(box[2] / 2), int(box[3] / 2)\n",
"        top_left = (x - half_w, y - half_h)\n",
"        bottom_right = (x + half_w, y + half_h)\n",
"\n",
"        text = \"{} {:.2f}\".format(label.decode(), prob)\n",
"        cv2.rectangle(frame, top_left, bottom_right, color, 3)\n",
"        cv2.putText(frame, text, top_left, cv2.FONT_HERSHEY_PLAIN, 2, color, 2)\n",
"\n",
"    if show:\n",
"        plt.imshow(frame)\n",
"        plt.show()\n",
"\n",
"    return frame"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Input settings.\n",
"\n",
"# Frame rate of the input video (hard-coded; should match the input file).\n",
"fps = 60\n",
"\n",
"# Analyse one frame every sample_interval_sec seconds; skipping frames keeps processing time down.\n",
"sample_interval_sec = 0.5\n",
"# Clamp to at least 1 so `i % skip_frame_count` in the processing loop never divides by zero.\n",
"skip_frame_count = max(1, int(fps * sample_interval_sec))\n",
"\n",
"\n",
"path = \"~/Downloads/16320004.MOV\"\n",
"video_filepath = os.path.expanduser(path) \n",
"\n",
"# Base name of the input file, used to name the debug frames and the output video.\n",
"name, _ext = os.path.splitext(os.path.basename(video_filepath))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the wall-clock time at which processing starts.\n",
"print(f\"{datetime.now():%Y/%m/%d %H:%M:%S}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Read the video and run detection on every skip_frame_count-th frame.\n",
"cap = skvideo.io.vreader(video_filepath)\n",
"\n",
"# Debug copies of processed frames are saved under frame/; create it up front\n",
"# so skimage.io.imsave does not fail when the directory is missing.\n",
"os.makedirs(\"frame\", exist_ok=True)\n",
"\n",
"# One entry per processed frame: the motorbike detections found in that frame.\n",
"result = []\n",
"\n",
"for i, frame in enumerate(cap):\n",
"    # Processing every frame takes too long, so only every skip_frame_count-th frame is analysed.\n",
"    if i % skip_frame_count != 0:\n",
"        continue\n",
"    # Uncomment to stop after 10 processed frames (quick check that everything works).\n",
"    #if len(result) >= 10:\n",
"    #    break\n",
"\n",
"    # Report which frame is being processed.\n",
"    display.clear_output(wait=True)\n",
"    print(\"frame:{}\".format(i))\n",
"    print(datetime.now().strftime(\"%Y/%m/%d %H:%M:%S\"))\n",
"\n",
"    # Time the detection step.\n",
"    start = time.time()\n",
"\n",
"    # Object detection with a 0.7 threshold.\n",
"    im = array_to_image(frame)\n",
"    r = dn.detect2(net, meta, im, 0.7)\n",
"\n",
"    # Time spent on this frame.\n",
"    elapsed_time = time.time() - start\n",
"    print(\"time:{} sec\".format(elapsed_time))\n",
"\n",
"    # All detected objects.\n",
"    print(r)\n",
"    # Draw the motorbike boxes onto the frame and display it.\n",
"    draw_bb(r, frame)\n",
"\n",
"    # Save the frame as a JPEG for debugging.\n",
"    output_path = os.path.join(\"frame\", \"{}_frame{}.jpg\".format(name, i))\n",
"    skimage.io.imsave(output_path, frame)\n",
"\n",
"    # Keep only motorbike detections.\n",
"    # TODO: rider (person) size might help too; parked motorbikes (e.g. in car parks) are also detected.\n",
"    # TODO: scoring by box area * probability may work better.\n",
"    result.append([obj for obj in r if obj[0] == b'motorbike'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the wall-clock time at which processing finished.\n",
"print(f\"{datetime.now():%Y/%m/%d %H:%M:%S}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare the analysis results for plotting.\n",
"\n",
"# Largest number of motorbikes seen together in one frame (0 when nothing was analysed).\n",
"max_bike_size = max((len(b) for b in result), default=0)\n",
"\n",
"# Per-frame motorbike probabilities. Frames contain different numbers of motorbikes,\n",
"# so every row is zero-padded to max_bike_size; the result is always a 2-D float array\n",
"# of shape (number of analysed frames, max_bike_size).\n",
"zero_pad_result = np.zeros((len(result), max_bike_size))\n",
"for row, bikes in enumerate(result):\n",
"    zero_pad_result[row, :len(bikes)] = [bike[1] for bike in bikes]\n",
"# zero_pad_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# x axis: source-frame index of each analysed sample.\n",
"x = [k * skip_frame_count for k in range(len(zero_pad_result))]\n",
"\n",
"# One trace per motorbike slot (column of zero_pad_result); padding plots as 0.\n",
"data = [plotly.graph_objs.Scatter(x=x, y=column) for column in zero_pad_result.T]\n",
"layout = plotly.graph_objs.Layout()\n",
"fig = plotly.graph_objs.Figure(data=data, layout=layout)\n",
"plotly.offline.iplot(fig)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Total motorbike probability per analysed frame (row sums of zero_pad_result).\n",
"data = [plotly.graph_objs.Scatter(x=x, y=list(zero_pad_result.sum(axis=1)))]\n",
"layout = plotly.graph_objs.Layout()\n",
"fig = plotly.graph_objs.Figure(data=data, layout=layout)\n",
"plotly.offline.iplot(fig)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Write the output video: keep only the frames where motorbikes pass by.\n",
"cap = skvideo.io.vreader(video_filepath)\n",
"\n",
"writer = skvideo.io.FFmpegWriter(\"{}_outputvideo.mp4\".format(name), \n",
"                                 inputdict={\n",
"                                     \"-r\": str(fps)\n",
"                                 })\n",
"\n",
"# A frame is kept when the summed motorbike probability of its sample exceeds this value.\n",
"threshold = 1.8\n",
"\n",
"# Score per analysed sample = sum of its motorbike probabilities (padding adds 0).\n",
"frame_scores = zero_pad_result.sum(axis=1)\n",
"\n",
"try:\n",
"    for i, frame in enumerate(cap):\n",
"        # Sample k was taken from frame k * skip_frame_count, so each frame reuses\n",
"        # the score of the latest sample at or before it.\n",
"        sample_index = i // skip_frame_count\n",
"        # The analysis may cover fewer frames (e.g. the debug break in the detection\n",
"        # loop); stop there instead of raising IndexError.\n",
"        if sample_index >= len(frame_scores):\n",
"            break\n",
"        # TODO: extend clips before/after and smooth out misdetection noise.\n",
"        if frame_scores[sample_index] > threshold:\n",
"            writer.writeFrame(frame)\n",
"finally:\n",
"    # Always close the writer so the output file is finalised, even after an error.\n",
"    writer.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment