Skip to content

Instantly share code, notes, and snippets.

@ven-kyoshiro
Last active February 26, 2021 05:57
Show Gist options
  • Save ven-kyoshiro/553a0453a0f3b2624797f1335c4c7a78 to your computer and use it in GitHub Desktop.
Save ven-kyoshiro/553a0453a0f3b2624797f1335c4c7a78 to your computer and use it in GitHub Desktop.
video2activity_remove_dontmove_are.py
"""
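# Changes
- video2activity can now be told which folder the images are written to
- video2activity('200922_4_10min.mp4','test.csv',visualize=True,visualize_dir='hoge')
- in the case above, the images are written under hoge
- Added a step that detects the regions that stay black throughout the video and paints them white
- video2activity('200922_4_10min.mp4','test.csv',remove_dontmove=True)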
"""
import tqdm
import time
import pandas as pd
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
def measure_position(arr, mouse_area_max, mouse_area_min,
                     eps, min_sample):
    """Locate the pixel cluster labelled 0 by DBSCAN inside an intensity band.

    The frame is reduced to a 2-D intensity image by averaging over axis 2.
    Every pixel whose intensity lies in ``[mouse_area_min, mouse_area_max]``
    becomes a candidate point, and the candidates are clustered with DBSCAN.
    The position is computed from the points that received cluster label 0
    (presumably the mouse -- nothing here checks that it is the largest
    cluster).

    Args:
        arr: 3-D array; axis 2 is averaged away (e.g. the colour channels).
        mouse_area_max: upper intensity bound (inclusive).
        mouse_area_min: lower intensity bound (inclusive).
        eps: ``eps`` argument passed to DBSCAN.
        min_sample: ``min_samples`` argument passed to DBSCAN.

    Returns:
        Tuple ``(mean_x, mean_y, med_x, med_y, labels, X)`` where ``x`` is the
        index along axis 0 of ``arr`` and ``y`` the index along axis 1,
        ``labels`` is the DBSCAN label of every candidate point and ``X`` is
        the ``(n_points, 2)`` array of candidate coordinates.  When there is
        no candidate pixel, or no point was given label 0, the four position
        values are NaN.
    """
    gray = np.mean(arr, axis=2)
    in_range = (gray >= mouse_area_min) & (gray <= mouse_area_max)
    x, y = np.where(in_range)
    X = np.stack([x, y], axis=1)

    if X.shape[0] == 0:
        # DBSCAN.fit raises ValueError on an array with zero samples, so a
        # frame without any candidate pixel is reported as "no position".
        return np.nan, np.nan, np.nan, np.nan, np.empty(0, dtype=np.intp), X

    labels = DBSCAN(eps=eps, min_samples=min_sample).fit(X).labels_
    is_mouse = labels == 0
    if not is_mouse.any():
        # No cluster was formed; avoid the mean/median-of-empty RuntimeWarning
        # and return the same NaN values explicitly.
        return np.nan, np.nan, np.nan, np.nan, labels, X

    mouse_xs = x[is_mouse]
    mouse_ys = y[is_mouse]
    return (np.mean(mouse_xs), np.mean(mouse_ys),
            np.median(mouse_xs), np.median(mouse_ys), labels, X)
def colorlize_mouse(labels, X, i, records, H, W, vis_dir="."):
    """Save a diagnostic plot of the clustered points for one frame.

    The figure shows every point of ``X`` coloured by its entry in ``labels``,
    an ``x`` marker at the mean (light green) and at the median (orange) of
    the points labelled 0, and the trajectories stored in ``records``.  It is
    written to ``{vis_dir}/mouse_img_{i}.png``.

    Args:
        labels: one label per row of ``X``.
        X: ``(n_points, 2)`` array; column 0 is drawn on the horizontal axis.
        i: value inserted into the output file name.
        records: mapping providing ``mean_x``, ``mean_y``, ``med_x``, ``med_y``.
        H: upper limit of the horizontal axis.
        W: upper limit of the vertical axis.
        vis_dir: directory the image is saved into.
    """
    plt.figure(figsize=(12, 8))
    # TODO (kept from the original author): these limits may be swapped.
    plt.xlim(0, H)
    plt.ylim(0, W)

    first_col = X[:, 0]
    second_col = X[:, 1]

    # Draw every candidate point, coloured by cluster label.
    plt.scatter(first_col, second_col, c=labels, cmap='jet', s=0.5)

    # Mark the centre of the cluster labelled 0, once per statistic.
    cluster_idx = np.where(labels == 0)
    cluster_x = first_col[cluster_idx]
    cluster_y = second_col[cluster_idx]
    for centre_of, colour in ((np.mean, 'lightgreen'), (np.median, 'orange')):
        plt.scatter([centre_of(cluster_x)], [centre_of(cluster_y)],
                    color=colour, marker='x')

    # Draw the trajectory recorded so far for each statistic.
    for key_x, key_y, colour in (('mean_x', 'mean_y', 'lightgreen'),
                                 ('med_x', 'med_y', 'orange')):
        plt.plot(records[key_x], records[key_y], color=colour)

    plt.savefig(f'{vis_dir}/mouse_img_{i}.png', dpi=200)
    plt.close()
def get_dontmove_mask(video_name, binalize_th=70, dontmove_rate=0.9, skip_frame=30):
    """Build a mask of the pixels that stay dark for most of the video.

    Args:
        video_name: path of the video to read.
        binalize_th: a pixel counts as dark in a frame when the mean of its
            three channels is strictly below this value.
        dontmove_rate: a pixel is marked in the mask when the fraction of
            sampled frames in which it is dark is strictly above this value.
        skip_frame: only frames whose index is a multiple of ``skip_frame``
            are sampled; the others are read and discarded.

    Returns:
        Boolean ``np.ndarray`` of shape ``(H, W)``.  ``True`` marks the
        never-moving (always dark) region, ``False`` everything else.  If no
        frame could be sampled the mask is all ``False``.
    """
    print('detecting dontmove_area')
    cap = cv2.VideoCapture(video_name)
    try:
        # Basic properties of the video.
        frame_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Per-pixel count of sampled frames in which the pixel was dark.
        dark_counts = np.zeros((H, W))
        sampled = 0
        for i in tqdm.tqdm(range(frame_len)):
            ret, frame = cap.read()
            if i % skip_frame != 0:
                continue
            if not ret:
                break
            frame = np.reshape(frame, (H, W, 3))
            dark_counts += np.mean(frame, axis=2) < binalize_th
            sampled += 1
    finally:
        # The capture was previously leaked; always give the handle back.
        cap.release()

    if sampled == 0:
        # Avoid 0/0: with nothing sampled no pixel can be called static.
        return np.zeros((H, W), dtype=bool)

    # True where the dark ratio exceeds the threshold, False elsewhere.
    return dark_counts / sampled > dontmove_rate
def video2activity(video_name, csv_name, interval=300, mouse_area_max=70, mouse_area_min=30,
                   eps=20, min_sample=700, visualize=False, visualize_dir="visualize_result",
                   remove_dontmove=True, skip_frame=30):
    """Track the position in a video and write the per-sample activity to CSV.

    Every ``interval``-th frame is passed to ``measure_position``.  The
    distance between consecutive sampled positions is recorded as the
    activity (0 for the first sample) and the table is written to
    ``csv_name`` with the columns ``frame``, ``mean_activity``,
    ``med_activity``, ``mean_x``, ``mean_y``, ``med_x`` and ``med_y``.

    Args:
        video_name: path of the video to analyse.
        csv_name: path of the CSV file to write.
        interval: a frame is analysed when its index is a multiple of this.
        mouse_area_max: upper intensity bound forwarded to ``measure_position``.
        mouse_area_min: lower intensity bound forwarded to ``measure_position``.
        eps: clustering radius forwarded to ``measure_position``.
        min_sample: clustering density forwarded to ``measure_position``.
        visualize: when true, diagnostic images are saved to ``visualize_dir``.
        visualize_dir: output directory for the images; created if missing.
        remove_dontmove: when true, the mask from ``get_dontmove_mask`` is
            painted white (255) in every analysed frame before measuring.
        skip_frame: sampling step used while building that mask.

    Raises:
        AssertionError: if the video cannot be opened.
    """
    st = time.time()

    # Create the output directory first: the mask image below is saved into
    # it, which used to fail when the directory did not exist yet.
    if visualize:
        os.makedirs(visualize_dir, exist_ok=True)

    if remove_dontmove:
        dontmove_area = get_dontmove_mask(video_name, skip_frame=skip_frame)
        if visualize:
            plt.figure(dpi=100)
            plt.imshow(dontmove_area.T)
            plt.savefig(f'{visualize_dir}/dontmove_area.png', dpi=200)
            plt.close()
        dontmove_area = dontmove_area[np.newaxis, :, :, np.newaxis]  # shape=(1,H,W,1)

    # All columns start empty so they stay the same length even when no
    # frame is analysed (pd.DataFrame rejects columns of unequal length).
    records = {
        'frame': [],
        'mean_activity': [],
        'med_activity': [],
        'mean_x': [],
        'mean_y': [],
        'med_x': [],
        'med_y': [],
    }

    cap = cv2.VideoCapture(video_name)
    try:
        if not cap.isOpened():
            # Raised explicitly (same exception type as the former assert) so
            # the check is not stripped when Python runs with -O.
            raise AssertionError('could not read the video')
        frame_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

        for i in tqdm.tqdm(range(frame_len)):
            ret, frame = cap.read()
            if not ret:
                break
            if i % interval != 0:
                continue

            # frame -> (1,H,W,3)
            frame = np.reshape(frame, (1, H, W, 3))
            if remove_dontmove:
                # Force the never-moving region to 255 (white).
                frame = np.clip(dontmove_area * 255. + frame, 0, 255)

            mean_x, mean_y, med_x, med_y, labels, X = measure_position(
                frame[0], mouse_area_max=mouse_area_max,
                mouse_area_min=mouse_area_min, eps=eps, min_sample=min_sample)

            if records['mean_x']:
                # Euclidean distance from the previously sampled position.
                mean_activity = np.sqrt(
                    (records['mean_x'][-1] - mean_x) ** 2
                    + (records['mean_y'][-1] - mean_y) ** 2)
                med_activity = np.sqrt(
                    (records['med_x'][-1] - med_x) ** 2
                    + (records['med_y'][-1] - med_y) ** 2)
            else:
                # The first sample has no predecessor, so its activity is 0.
                mean_activity = 0
                med_activity = 0

            records['frame'].append(i)
            records['mean_activity'].append(mean_activity)
            records['med_activity'].append(med_activity)
            records['mean_x'].append(mean_x)
            records['mean_y'].append(mean_y)
            records['med_x'].append(med_x)
            records['med_y'].append(med_y)

            if visualize:
                colorlize_mouse(labels, X, i, records, H, W, vis_dir=visualize_dir)
    finally:
        # Release the capture even when measuring or plotting raises.
        cap.release()

    print(f'it takes {time.time()-st}[sec]')
    df = pd.DataFrame(records)
    df.to_csv(csv_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment