Skip to content

Instantly share code, notes, and snippets.

@Pked01
Created January 6, 2019 11:15
Show Gist options
  • Save Pked01/8115f6c1e3c4c625ae7d39c548e00407 to your computer and use it in GitHub Desktop.
Save Pked01/8115f6c1e3c4c625ae7d39c548e00407 to your computer and use it in GitHub Desktop.
ABG/vision_related/libs/image-labelling-tool/examples/ssd/fasterrcnn_training_chainer.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:16:25.678729Z",
"end_time": "2019-01-06T09:16:33.848508Z"
},
"trusted": true
},
"cell_type": "code",
"source": "import argparse\nimport matplotlib.pyplot as plot\nimport yaml,pickle\n\nimport chainer,cv2,time,datetime,os\n\n#from chainercv.datasets import voc_detection_label_names\nfrom chainercv.links import SSD300\nfrom chainercv.links import FasterRCNNVGG16\nfrom chainercv import utils\nfrom chainercv.visualizations import vis_bbox,vis_image\n\nfrom original_detection_dataset import OriginalDetectionDataset\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom chainer import serializers\nimport IPython.display as Disp\n\nimport pandas as pd\n\nimport _pickle as cpickle",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### loading model"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "- SSD"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-12-06T08:23:20.749435Z",
"start_time": "2018-12-06T08:23:15.887270Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# SSD300 checkpoint with 5 foreground classes; the tuple order is assumed to match the label order used at training time.\nlabel_names = ('helmet', 'no_helmet', 'vest', 'no_vest', 'person')\nmodel = SSD300(n_fg_class=len(label_names),\n               pretrained_model='result_with_human/model_iter_74000')",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "- Faster-RCNN"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:35:53.784653Z",
"end_time": "2019-01-06T09:35:57.573046Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Faster R-CNN (VGG16) checkpoint. NOTE: 'person' is class 0 here but class 4 in the SSD cell above,\n# so anything that hard-codes class indices must follow whichever model cell ran last.\nlabel_names = ('person', 'helmet', 'no_helmet', 'vest', 'no_vest')\nmodel = FasterRCNNVGG16(n_fg_class=len(label_names),\n                        pretrained_model='model_faster_rcnn_new/model_iter_120000')",
"execution_count": 22,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### setting gpu usage"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:35:57.574988Z",
"end_time": "2019-01-06T09:35:57.863316Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Make GPU 0 the current CUDA device, then move the model's parameters onto it.\nchainer.cuda.get_device_from_id(0).use()\nmodel.to_gpu()",
"execution_count": 23,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 23,
"data": {
"text/plain": "<chainercv.links.model.faster_rcnn.faster_rcnn_vgg.FasterRCNNVGG16 at 0x7fb5251e5240>"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### setting threshold"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:16:42.063695Z",
"end_time": "2019-01-06T09:16:42.066500Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Confidence threshold applied inside model.predict() to discard low-scoring detections.\nmodel.score_thresh = 0.6",
"execution_count": 5,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### check intersection between two bounding boxes"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:16:42.068909Z",
"end_time": "2019-01-06T09:16:42.176351Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def check_intersection(a, b):\n    \"\"\"Intersect two axis-aligned rectangles, each given as (x, y, w, h).\n\n    Returns (True, (x, y, w, h)) describing the overlap, or (False, ()) when the\n    rectangles are disjoint. Rectangles that only touch (zero-width or zero-height\n    overlap) are reported as intersecting.\n    \"\"\"\n    left = max(a[0], b[0])\n    top = max(a[1], b[1])\n    right = min(a[0] + a[2], b[0] + b[2])\n    bottom = min(a[1] + a[3], b[1] + b[3])\n    width, height = right - left, bottom - top\n    if width < 0 or height < 0:\n        return False, ()\n    return True, (left, top, width, height)",
"execution_count": 6,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### converting bbox format"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:16:42.178097Z",
"end_time": "2019-01-06T09:16:42.260653Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def convert_y1x1y2x2toxywh(rect):\n    \"\"\"Convert a (y1, x1, y2, x2) box - the order the detector outputs use in this notebook - into [x, y, w, h].\"\"\"\n    y_min, x_min, y_max, x_max = rect[0], rect[1], rect[2], rect[3]\n    return [x_min, y_min, x_max - x_min, y_max - y_min]",
"execution_count": 7,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### drawing bounding boxes and labels on a frame"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:29:56.089845Z",
"end_time": "2019-01-06T09:29:56.167940Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def _person_filter_rejects(bboxes, label_ids, person_idx):\n    \"\"\"Return the set of box indices that overlap no detected person.\n\n    bboxes     : list of (y1, x1, y2, x2) boxes\n    label_ids  : class index of each box (same order as bboxes)\n    person_idx : class index that means 'person'\n    Person boxes always overlap themselves, so they are never rejected. When no\n    person is detected at all, every box is rejected.\n    \"\"\"\n    person_rects = [convert_y1x1y2x2toxywh(box)\n                    for box, lab in zip(bboxes, label_ids) if lab == person_idx]\n    rejected = set()\n    for k, box in enumerate(bboxes):\n        rect = convert_y1x1y2x2toxywh(box)\n        if not any(check_intersection(rect, p)[0] for p in person_rects):\n            rejected.add(k)\n    return rejected\n\n\ndef create_bbox_img(img, bboxes, labels, scores, lower_side=False, use_middle=False,\n                    use_person_filter=False, person_label='person'):\n    \"\"\"Draw detection boxes and 'label : score' captions onto img (in place) and return it.\n\n    img        : image drawn on with cv2 (the callers in this notebook pass the BGR frame)\n    bboxes, labels, scores : model.predict([img]) output for a single image; element [0]\n                 holds the (y1, x1, y2, x2) boxes, class indices and scores\n    lower_side : caption 20 px below the box (default: 20 px above it)\n    use_middle : caption at the box's vertical centre (only when lower_side is False)\n    use_person_filter : skip every box that does not overlap a detected person\n    person_label : class name treated as 'person' by the filter; its index is looked up\n                 in the global label_names, so the filter follows whichever model\n                 cell ran last (class 4 for SSD, class 0 for Faster R-CNN)\n    \"\"\"\n    font = cv2.FONT_HERSHEY_SIMPLEX\n    # BGR colour per class index (position in label_names). With the Faster R-CNN order\n    # this gives person=black, helmet/vest=green, no_helmet/no_vest=red.\n    bbox_colors = [(0, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 0), (0, 0, 255)]\n    boxes = bboxes[0].astype(int).tolist()\n    label_ids = labels[0]\n    rounded_scores = np.array([round(s, 1) for s in scores[0]])\n\n    hidden = set()\n    if use_person_filter:\n        hidden = _person_filter_rejects(boxes, label_ids, label_names.index(person_label))\n\n    for idx, (y1, x1, y2, x2) in enumerate(boxes):\n        if idx in hidden:\n            continue  # rejected by the person filter: do not draw\n        cls = label_ids[idx]\n        cv2.rectangle(img, (x1, y1), (x2, y2), bbox_colors[cls], 2)\n        if lower_side:\n            text_y = y2 + 20\n        elif use_middle:\n            text_y = int(.5 * (y1 + y2))\n        else:\n            text_y = y1 - 20\n        caption = str(label_names[cls]) + ' : ' + str(rounded_scores[idx])\n        cv2.putText(img, caption, (x1, text_y), font, 0.8, (0, 0, 255), 1, cv2.LINE_AA)\n    return img",
"execution_count": 18,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### logging function"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:16:43.187998Z",
"end_time": "2019-01-06T09:16:43.270026Z"
},
"trusted": true
},
"cell_type": "code",
"source": "class logging_class(object):\n    \"\"\"Store violation evidence (video clips / CSV files) under one directory.\n\n    On construction creates <direc_path>/video_alerts and <direc_path>/text_alerts.\n    NOTE(review): threshold_time and violation_timer are stored but not used yet.\n    \"\"\"\n\n    def __init__(self, direc_path=None, threshold_time=5):\n        \"\"\"\n        direc_path     : directory to log into (defaults to the current working directory)\n        threshold_time : time for detection to start (seconds)\n        \"\"\"\n        # File names are timestamped in IST; tzset() makes the new TZ take effect in this\n        # process (time.tzset exists only on Unix).\n        os.environ['TZ'] = \"Asia/Calcutta\"\n        if hasattr(time, 'tzset'):\n            time.tzset()\n        self.cwd = os.getcwd() if direc_path is None else direc_path\n\n        self.threshold_time = threshold_time\n        os.makedirs(os.path.join(self.cwd, \"video_alerts\"), exist_ok=True)\n        os.makedirs(os.path.join(self.cwd, \"text_alerts\"), exist_ok=True)\n\n        self.vid_writer = None  # cv2.VideoWriter, opened lazily by add_frames\n        self.file_writer = None  # CSV file handle, opened by __write_file('start')\n        self.vid_filename = None\n        self.csv_file_name = None\n        self.violation_timer = None\n\n    @staticmethod\n    def _timestamp():\n        \"\"\"Current local time formatted for use in file names.\"\"\"\n        return datetime.datetime.now().strftime(\"%y_%m_%d_%H_%M_%S\")\n\n    def __write_video(self, flag='start', fps=20, img_size=(640, 480)):\n        \"\"\"\n        flag     : 'start' opens a new timestamped .mp4 writer (no-op if one is open);\n                   'stop' releases the current writer (no-op if none is open)\n        fps      : frame rate of the output video\n        img_size : (width, height) of the frames that will be written\n        \"\"\"\n        if flag == 'start':\n            if self.vid_writer is None:\n                self.vid_filename = os.path.join(self.cwd, \"video_alerts\", self._timestamp()) + '.mp4'\n                # NOTE(review): the 'X264' fourcc requires an OpenCV build with x264 support.\n                self.vid_writer = cv2.VideoWriter(self.vid_filename, cv2.VideoWriter_fourcc(*'X264'), fps, img_size)\n        elif flag == 'stop':\n            if self.vid_writer is not None:\n                self.vid_writer.release()\n            self.vid_writer = None\n\n    def __write_file(self, flag='start'):\n        \"\"\"\n        flag : 'start' opens a new timestamped .csv file (no-op if one is open);\n               'stop' closes it (no-op if none is open)\n        \"\"\"\n        if flag == 'start':\n            if self.file_writer is None:\n                self.csv_file_name = os.path.join(self.cwd, \"text_alerts\", self._timestamp()) + '.csv'\n                self.file_writer = open(self.csv_file_name, \"w+\")\n        elif flag == 'stop':\n            if self.file_writer is not None:\n                self.file_writer.close()\n            self.file_writer = None\n\n    def add_frames(self, image):\n        \"\"\"Append one frame to the current alert video, opening a writer first if needed.\"\"\"\n        if self.vid_writer is None:\n            self.__write_video('start')\n        self.vid_writer.write(image)\n\n    def add_info2file(self, list_data):\n        \"\"\"Placeholder for writing one record to the CSV alert file (not implemented).\"\"\"\n        pass\n\n    def close(self):\n        \"\"\"Release the video writer and close the CSV file, if either is open.\"\"\"\n        self.__write_video('stop')\n        self.__write_file('stop')",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-10-12T09:10:38.057224Z",
"start_time": "2018-10-12T09:09:09.931933Z"
}
},
"cell_type": "markdown",
"source": "#### write video"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:31:12.676627Z",
"end_time": "2019-01-06T09:31:17.264383Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Annotate every frame of a video file with detections and save the result.\nINPUT_VIDEO = 'WhatsApp Video 2018-10-21 at 9.51.59 AM.mp4'\n# TODO: machine-specific absolute path - make this relative / configurable.\nOUTPUT_VIDEO = '/home/prateek/Downloads/Notebooks/output_videos/Whatsapp_plant_op.mp4'\nOUTPUT_SIZE = (640, 480)  # (width, height); every written frame is resized to this\n\n# cv2.VideoWriter does not create missing directories itself.\nos.makedirs(os.path.dirname(OUTPUT_VIDEO), exist_ok=True)\nout = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'MJPG'), 20, OUTPUT_SIZE)\ncap = cv2.VideoCapture(INPUT_VIDEO)\ntry:\n    while True:\n        ret, frame = cap.read()\n        # Test ret BEFORE using frame: at end of stream frame is None, and rotating it used\n        # to raise, release the capture and `continue` - an endless loop of errors.\n        if not ret:\n            break\n        frame = cv2.rotate(frame, 0)  # 0 == cv2.ROTATE_90_CLOCKWISE\n        img_orig = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n        img = np.rollaxis(img_orig, 2, 0)  # HWC -> CHW for chainercv\n        bboxes, labels, scores = model.predict([img])\n        img_1 = create_bbox_img(frame, bboxes, labels, scores, lower_side=True, use_middle=False)\n        out.write(cv2.resize(img_1, OUTPUT_SIZE))\nfinally:\n    cap.release()\n    out.release()",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-12-05T10:20:43.507598Z",
"start_time": "2018-12-05T10:20:43.499413Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Safety net: release the video writer if the loop in the cell above was interrupted.\nout.release()",
"execution_count": 12,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### display video from webcam\n-- SSD\n- average fps in cpu mode is .5 (i7 7th gen)\n- average fps in gpu mode is 28 (gtx 1060)\n\n-- Faster-RCNN\n\n- average FPS in gpu mode is 5.3"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:40:41.276209Z",
"end_time": "2019-01-06T09:40:41.281320Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Stricter confidence threshold for the live preview below.\nmodel.score_thresh = 0.8",
"execution_count": 32,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-01-06T09:41:21.593784Z",
"end_time": "2019-01-06T09:42:01.700950Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Live preview: run the detector on a video stream, show annotated frames and log\n# per-class detection counts to a CSV file.\nVIDEO_SOURCE = 'Video_Data/Video/4.Hot_mill_view/Camera6_spandan office_spandan office_20181219094507_20181219094549_3182231.mp4'\nOUTPUT_DIR = 'Webcam Test'\nFRAMES_PER_STEP = 6  # frames read per iteration; only the last one is processed, to keep up with the stream\nfile_name = os.path.join(OUTPUT_DIR, 'output.csv')\nstack_file = os.path.join(OUTPUT_DIR, 'image_stack.pickle')\nos.makedirs(OUTPUT_DIR, exist_ok=True)\n\ncv2.namedWindow('preview', cv2.WINDOW_NORMAL)\ncap = cv2.VideoCapture(VIDEO_SOURCE)\nfps = int(cap.get(5))  # property 5 == cv2.CAP_PROP_FPS; can be 0 for some sources\n# A CSV row is written every `log_every` iterations; max(1, ...) avoids a modulo by zero when fps < 2.\nlog_every = max(1, int(fps * .5))\ndf = pd.DataFrame(columns=label_names)\ndf.to_csv(file_name)  # header row\nframe_no = 0\nop = []  # per-iteration class counts collected since the last CSV row\nvideo_stack = []\ntry:\n    while True:\n        frame_no += 1\n        t1 = time.time()\n        for _ in range(FRAMES_PER_STEP):\n            ret, frame = cap.read()\n            if not ret:\n                break\n        if not ret:\n            break  # end of stream / read failure: leave the loop instead of retrying a dead capture\n        img_orig = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n        img = np.rollaxis(img_orig, 2, 0)  # HWC -> CHW for chainercv\n        bboxes, labels, scores = model.predict([img])\n\n        # Append the median per-class count of the last window (skipped while the window is empty).\n        if frame_no % log_every == 0 and op:\n            df = pd.DataFrame(pd.DataFrame(op, columns=label_names).median()).T\n            df.to_csv(file_name, mode='a', header=False)\n            op = []\n        row = [0] * len(label_names)\n        for cls in labels[0]:\n            row[cls] += 1\n        op.append(row)\n\n        img_1 = create_bbox_img(frame, bboxes, labels, scores, lower_side=False, use_middle=False, use_person_filter=False)\n        img_1 = cv2.resize(img_1, (640, 480))\n        # Dump the most recent frames (at most 6) to disk - presumably polled by another process.\n        video_stack.append(img_1)\n        with open(stack_file, 'wb') as fh:\n            cpickle.dump(video_stack, fh)\n        video_stack = video_stack[1:6]\n        cv2.imshow('preview', img_1)\n        if cv2.waitKey(1) == 27:  # ESC\n            break\n        print(\"fps = {}\".format(1 / (time.time() - t1)))\n        Disp.clear_output(wait=True)\nfinally:\n    cap.release()\n    cv2.destroyAllWindows()",
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"text": "fps = 4.694025232194158\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"hide_input": false,
"latex_envs": {
"autoclose": false,
"LaTeX_envs_menu_present": true,
"eqLabelWithNumbers": true,
"user_envs_cfg": false,
"report_style_numbering": false,
"cite_by": "apalike",
"bibliofile": "biblio.bib",
"labels_anchors": false,
"current_citInitial": 1,
"autocomplete": true,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"latex_user_defs": false,
"eqNumInitial": 1
},
"language_info": {
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"name": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"file_extension": ".py"
},
"kernelspec": {
"name": "cv3",
"display_name": "cv3 (python3)",
"language": "python"
},
"gist": {
"id": "",
"data": {
"description": "ABG/vision_related/libs/image-labelling-tool/examples/ssd/fasterrcnn_training_chainer.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment