Created
January 6, 2019 11:15
-
-
Save Pked01/8115f6c1e3c4c625ae7d39c548e00407 to your computer and use it in GitHub Desktop.
ABG/vision_related/libs/image-labelling-tool/examples/ssd/fasterrcnn_training_chainer.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:16:25.678729Z", | |
"end_time": "2019-01-06T09:16:33.848508Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import argparse\nimport matplotlib.pyplot as plot\nimport yaml,pickle\n\nimport chainer,cv2,time,datetime,os\n\n#from chainercv.datasets import voc_detection_label_names\nfrom chainercv.links import SSD300\nfrom chainercv.links import FasterRCNNVGG16\nfrom chainercv import utils\nfrom chainercv.visualizations import vis_bbox,vis_image\n\nfrom original_detection_dataset import OriginalDetectionDataset\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom chainer import serializers\nimport IPython.display as Disp\n\nimport pandas as pd\n\nimport _pickle as cpickle", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### loading model" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "- SSD" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-12-06T08:23:20.749435Z", | |
"start_time": "2018-12-06T08:23:15.887270Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "label_names = ('helmet','no_helmet','vest','no_vest','person')\nmodel = SSD300(\n n_fg_class=len(label_names),\n pretrained_model='result_with_human/model_iter_74000')", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "- Faster-RCNN" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:35:53.784653Z", | |
"end_time": "2019-01-06T09:35:57.573046Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "label_names = ('person','helmet','no_helmet','vest','no_vest')\nmodel = FasterRCNNVGG16(\n n_fg_class=len(label_names),\n pretrained_model='model_faster_rcnn_new/model_iter_120000')", | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### setting gpu usage" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:35:57.574988Z", | |
"end_time": "2019-01-06T09:35:57.863316Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "chainer.cuda.get_device_from_id(0).use()\nmodel.to_gpu()", | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 23, | |
"data": { | |
"text/plain": "<chainercv.links.model.faster_rcnn.faster_rcnn_vgg.FasterRCNNVGG16 at 0x7fb5251e5240>" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### setting theshold" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:16:42.063695Z", | |
"end_time": "2019-01-06T09:16:42.066500Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "model.score_thresh =.6", | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### check intersection between two bbox" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:16:42.068909Z", | |
"end_time": "2019-01-06T09:16:42.176351Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def check_intersection(a, b):\n \"\"\"\n a : rectangle in (x,y,w,h)\n \"\"\"\n x = max(a[0], b[0])\n y = max(a[1], b[1])\n w = min(a[0]+a[2], b[0]+b[2]) - x\n h = min(a[1]+a[3], b[1]+b[3]) - y\n if w<0 or h<0: \n return False,() # or (0,0,0,0) ?\n return True,(x, y, w, h)", | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### converting bbox format" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:16:42.178097Z", | |
"end_time": "2019-01-06T09:16:42.260653Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def convert_y1x1y2x2toxywh(rect):\n return [rect[1],rect[0],rect[3]-rect[1],rect[2]-rect[0]]\n ", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### functionality to create bbox" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:29:56.089845Z", | |
"end_time": "2019-01-06T09:29:56.167940Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def create_bbox_img(img,bboxes,labels,scores,lower_side=False,use_middle=False,use_person_filter=False):\n \n font = cv2.FONT_HERSHEY_SIMPLEX\n bboxes = bboxes[0].astype(int).tolist()\n orig_scores = scores\n scores = [np.array([round(i,1) for i in orig_scores[0]])]\n if use_person_filter:\n failed_idx = []\n # 1. check if person is three\n # 2. if there's no person there's shouldn't be any detection\n if 4 in labels[0]:#4 is index for person\n idx_person = np.where(labels[0]==4)[0].tolist()\n for i in range(len(bboxes)): # for each index of bbox\n for j in idx_person: # for each person detection\n if check_intersection(convert_y1x1y2x2toxywh(bboxes[i]),convert_y1x1y2x2toxywh(bboxes[j]))[0]:\n # if intersection is positive \n try:\n failed_idx.remove(i) # remove i from failed index\n except:\n pass\n break # breaking out of loop\n else:\n failed_idx.append(i)\n else: # if no person detected\n failed_idx = list(range(len(labels)))\n \n \n for idx,i in enumerate(bboxes):\n # if in failed index then no show\n if use_person_filter:\n if idx in failed_idx:\n continue # do not print on screen\n #bbox\n #('helmet', 'no_helmet', 'vest', 'no_vest','person')\n bbox_colors = [(0,0,0),(0,255,0),(0,0,255),(0,255,0),(0,0,255)]\n cv2.rectangle(img,(i[1],i[0]),(i[3],i[2]),bbox_colors[labels[0][idx]],2)\n ##text\n #print(i[0],i[1],type(i[0]))\n\n \n \n if lower_side:\n cv2.putText(img,str(label_names[labels[0][(idx)]]) +' : '+str(scores[0][(idx)]), (i[1],i[2]+20), font, 0.8, (0, 0, 255),1, cv2.LINE_AA)\n elif use_middle:\n cv2.putText(img,str(label_names[labels[0][(idx)]]) +' : '+str(scores[0][(idx)]), (i[1],int(.5*(i[0]+i[2]))), font, 0.8, (0, 0, 255),1, cv2.LINE_AA)\n else: \n cv2.putText(img,str(label_names[labels[0][(idx)]]) +' : '+str(scores[0][(idx)]), (i[1],i[0]-20), font, 0.8, (0, 0, 255),1, cv2.LINE_AA)\n \n return img", | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### logging function" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:16:43.187998Z", | |
"end_time": "2019-01-06T09:16:43.270026Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def logging_class(object):\n def __init__(direc_path = None,threshold_time = 5):\n \"\"\"\n direc_path : path where you want to log information \n threshold_time : time for detection to start(Seconds)\n \n \"\"\"\n \n os.environ['TZ'] = \"Asia/Calcutta\"\n if os.direc_path is None:\n self.cwd = os.getcwd()\n else:\n self.cwd = direc_path\n \n self.threshold_time = threshold_time\n os.makedirs(os.path.join(self.cwd,\"video_alerts\"),exist_ok=True)\n os.makedirs(os.path.join(self.cwd,\"text_alerts\"),exist_ok=True)\n \n self.vid_writer = None # initializing video writer object\n self.file_writer = None # initialize text writer\n self.vid_filename = None\n self.csv_file_name = None\n self.violation_timer\n \n \n def __write_video(self,flag = 'start',fps=20,img_size=(640,480)):\n \"\"\"\n flag : start,stop want to start writing video\n fps : fps rate of vide\n img_size : img size in width, height\n \n vid_writer object will be created when violation created and it will be release only if \n threshold time is over\n \"\"\"\n if flag=='start':\n if self.vid_writer is None:\n self.vid_filename = os.path.join(self.cwd,\"video_alerts\",datetime.datetime.now().strftime(\"%y_%m_%d_%H_%M_%S\"))+'.mp4'\n self.vid_writer = cv2.VideoWriter(self.vid_filename,cv2.VideoWriter_fourcc(*'X264'), fps,img_size)\n if flag=='stop':\n self.vid_writer.release()\n self.vid_writer = None\n def __write_file(self,flag = 'start'):\n \"\"\"\n flag= start/stop\n \"\"\"\n if flag=='start':\n if self.file_writer is None:\n self.csv_file_name = os.path.join(self.cwd,\"text_alerts\",datetime.datetime.now().strftime(\"%y_%m_%d_%H_%M_%S\"))+'.csv'\n #self.file_writer = open(datetime.datetime.now().strftime(\"%y_%m_%d_%H_%M_%S\"),\"w+\") \n if flag=='stop':\n self.file_writer.close()\n self.file_writer = None\n def add_frames(self,image):\n if self.vid_writer is None:\n self.__write_video('start')\n self.vid_writer.write(image)\n def add_info2file(self,list_data):\n pass\n\n\n \n \n ", | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-10-12T09:10:38.057224Z", | |
"start_time": "2018-10-12T09:09:09.931933Z" | |
} | |
}, | |
"cell_type": "markdown", | |
"source": "#### write video" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:31:12.676627Z", | |
"end_time": "2019-01-06T09:31:17.264383Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#out = cv2.VideoWriter('test_vid/test_vid_2_op.mp4',cv2.VideoWriter_fourcc('M','J','P','G'), 20, (640,480))\nout = cv2.VideoWriter('/home/prateek/Downloads/Notebooks/output_videos/Whatsapp_plant_op.mp4',cv2.VideoWriter_fourcc(*'MJPG'), 20, (640,480))\n# for i in files:H264\n# try:\n# img = utils.read_image(path+'/'+i, color=True)\n# except Exception as e:\n# continue\n# print(e)\n\ncap = cv2.VideoCapture('WhatsApp Video 2018-10-21 at 9.51.59 AM.mp4')\n#cap.set(1,100)\nwhile True: \n try:\n ret,frame = cap.read()\n frame = cv2.rotate(frame,0)\n if not ret:\n break\n img_orig = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n img = np.rollaxis(img_orig, 2, 0) \n\n #img = utils.read_image(path+'/'+i, color=True)\n except Exception as e:\n print(e)\n cap.release()\n out.release()\n continue\n\n bboxes, labels, scores = model.predict([img])\n \n #img = cv2.imread(path+'/'+i)///''\n\n img_1 = create_bbox_img(frame,bboxes,labels,scores,lower_side=True,use_middle=False)\n \n# plt.imshow(img_orig)\n# plt.show()\n# print(len(bboxes[0]))\n# Disp.clear_output(wait=True)\n #img_1 = cv2.cvtColor(img_1,cv2.COLOR_RGB2BGR)\n img_1 = cv2.resize(img_1,(640,480))\n out.write(img_1)\n\nout.release()", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-12-05T10:20:43.507598Z", | |
"start_time": "2018-12-05T10:20:43.499413Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "out.release()", | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### display video from webcam\n-- SSD\n- average fps in cpu mode is .5 (i7 7th gen)\n- average fps in gpu mode is 28 (gtx 1060)\n\n-- Faster-RCNN\n\n- average FPS in gpu mode is 5.3" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:40:41.276209Z", | |
"end_time": "2019-01-06T09:40:41.281320Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "model.score_thresh =.8", | |
"execution_count": 32, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2019-01-06T09:41:21.593784Z", | |
"end_time": "2019-01-06T09:42:01.700950Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "cv2.namedWindow('preview',cv2.WINDOW_NORMAL)\ncap = cv2.VideoCapture('Video_Data/Video/4.Hot_mill_view/Camera6_spandan office_spandan office_20181219094507_20181219094549_3182231.mp4')\n#log_info = logging_class()\n#cap.set(1,100)\nfps = int(cap.get(5))\n#file_name = datetime.datetime.now().strftime(\"%y_%m_%d_%H_%M_%S\")+'.csv'\nfile_name = 'Webcam Test/output.csv'\ndf = pd.DataFrame(columns=label_names)\ndf.to_csv(file_name) \nframe_no = 0\nop = [] # file_output\nvideo_stack = []\nwhile True:\n frame_no+=1\n t1 = time.time()\n try:\n # clearing buffer \n for i in range(6):\n ret,frame = cap.read()\n #frame = cv2.rotate(frame,0)\n if not ret:\n break\n img_orig = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n img = np.rollaxis(img_orig, 2, 0) \n\n #img = utils.read_image(path+'/'+i, color=True)\n except Exception as e:\n print(e)\n cap.release()\n continue\n\n bboxes, labels, scores = model.predict([img])\n \n ############ writing all data ###############\n if frame_no%(int(fps*.5))==0: # seconds(1) to reset\n df = pd.DataFrame(pd.DataFrame(op,columns=label_names).median()).T\n df.to_csv(file_name,mode='a',header=False)\n op = []\n row = [0]*len(label_names)\n for i in labels[0]:\n row[i]+=1\n op.append(row)\n ############ data writing logic #############\n \n #############################################\n \n #img = cv2.imread(path+'/'+i)///''\n\n img_1 = create_bbox_img(frame,bboxes,labels,scores,lower_side=False,use_middle=False,use_person_filter=False)\n \n# plt.imshow(img_orig)\n# plt.show()\n# print(len(bboxes[0]))\n# Disp.clear_output(wait=True)\n #img_1 = cv2.cvtColor(img_1,cv2.COLOR_RGB2BGR)\n img_1 = cv2.resize(img_1,(640,480))\n video_stack.append(img_1)\n cpickle.dump(video_stack,open('Webcam Test/image_stack.pickle','wb'))\n video_stack = video_stack[1:6]\n cv2.imshow('preview',img_1)\n k=cv2.waitKey(1)\n if k==27:\n cap.release()\n break\n #cv2.imwrite('Webcam Test/temp.jpg',img_1)\n print(\"fps = {}\".format(1/(time.time()-t1)))\n Disp.clear_output(wait=True)\ncap.release()\ncv2.destroyAllWindows()", | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "fps = 4.694025232194158\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"latex_envs": { | |
"autoclose": false, | |
"LaTeX_envs_menu_present": true, | |
"eqLabelWithNumbers": true, | |
"user_envs_cfg": false, | |
"report_style_numbering": false, | |
"cite_by": "apalike", | |
"bibliofile": "biblio.bib", | |
"labels_anchors": false, | |
"current_citInitial": 1, | |
"autocomplete": true, | |
"hotkeys": { | |
"equation": "Ctrl-E", | |
"itemize": "Ctrl-I" | |
}, | |
"latex_user_defs": false, | |
"eqNumInitial": 1 | |
}, | |
"language_info": { | |
"mimetype": "text/x-python", | |
"nbconvert_exporter": "python", | |
"name": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2", | |
"codemirror_mode": { | |
"version": 3, | |
"name": "ipython" | |
}, | |
"file_extension": ".py" | |
}, | |
"kernelspec": { | |
"name": "cv3", | |
"display_name": "cv3 (python3)", | |
"language": "python" | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "ABG/vision_related/libs/image-labelling-tool/examples/ssd/fasterrcnn_training_chainer.ipynb", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment