Created
March 11, 2024 06:10
-
-
Save dineshdharme/97cc763d57a9088726aa56c7b8eb174f to your computer and use it in GitHub Desktop.
An example demonstrating how to use PySpark for video processing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
I adapted the following Jupyter notebook to show how Spark can process video at scale:
https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1969271421694072/3760413548916830/5612335034456173/latest.html
Install the following Python libraries in your conda environment, and make sure the native FFmpeg binary is installed as well (`ffmpeg-python` only wraps it):
`pip install ffmpeg-python`
`pip install face-recognition`
`conda install -c conda-forge opencv`
Download an `.mp4` video that contains a face and save it in `../data/video_files/faces/` (the input directory read by the code below). Free clips are available at:
`https://www.videezy.com/free-video/face?format-mp4=true`
The PySpark code follows:
from pyspark import SQLContext, SparkConf, SparkContext | |
from pyspark.sql import SparkSession | |
import pyspark.sql.functions as F | |
# Local Spark session: `local[40]` runs 40 worker threads in this process and
# the driver is given a 30 GB heap.
conf = SparkConf().setAppName("myApp").setMaster("local[40]")
# Pass `conf` to the builder. Previously it was built but never used, so the
# "myApp" application name was silently dropped.
spark = (
    SparkSession.builder.config(conf=conf)
    .config("spark.driver.memory", "30g")
    .getOrCreate()
)
sc = spark.sparkContext
# Kept because later code calls `sqlContext.createDataFrame`; SQLContext is
# deprecated since Spark 3.0 in favour of the SparkSession itself.
sqlContext = SQLContext(sc)
import cv2 | |
import os | |
import uuid | |
import ffmpeg | |
import subprocess | |
import numpy as np | |
from scipy.optimize import linear_sum_assignment | |
import pyspark.sql.functions as F | |
from pyspark.sql import Row | |
from pyspark.sql.types import (StructType, StructField, | |
IntegerType, FloatType, | |
ArrayType, BinaryType, | |
MapType, DoubleType, StringType) | |
from pyspark.sql.window import Window | |
from pyspark.ml.feature import StringIndexer | |
from pyspark.sql import Row, DataFrame, SparkSession | |
import pathlib | |
videos = []  # NOTE(review): not referenced anywhere in this script
input_dir = "../data/video_files/faces/"
# glob() yields files in arbitrary order; sort for a deterministic row order.
pathlist = [Row(str(path)) for path in sorted(pathlib.Path(input_dir).glob("*.mp4"))]
print(pathlist)
if not pathlist:
    # createDataFrame cannot build a DataFrame from an empty list with only
    # column names, so fail early with an actionable message.
    raise FileNotFoundError("No .mp4 files found in {!r}".format(input_dir))
column_name = ["video_uri"]
df = sqlContext.createDataFrame(data=pathlist, schema=column_name)
print("Initial dataframe")
df.show(10, truncate=False)
# Struct produced by `video_probe`: frame size, frame count and duration.
# Every field is declared non-nullable.
video_metadata = StructType(
    [
        StructField(field_name, field_type, False)
        for field_name, field_type in (
            ("width", IntegerType()),
            ("height", IntegerType()),
            ("num_frames", IntegerType()),
            ("duration", FloatType()),
        )
    ]
)

# Array of (start, end) segments.
# NOTE(review): not referenced elsewhere in this script.
shots_schema = ArrayType(
    StructType(
        [StructField(bound, FloatType(), False) for bound in ("start", "end")]
    )
)
@F.udf(returnType=video_metadata)
def video_probe(uri):
    """Probe a video with ffprobe and return its basic metadata.

    Parameters
    ----------
    uri : str
        Path or URI of the video, as accepted by ``ffmpeg.probe``.

    Returns
    -------
    tuple
        ``(width, height, num_frames, duration)``, matching ``video_metadata``.

    Raises
    ------
    ValueError
        If the file has no video stream, or if its duration or frame count
        cannot be determined.
    """
    from fractions import Fraction

    probe = ffmpeg.probe(uri, threads=1)
    video_stream = next(
        (stream for stream in probe["streams"] if stream.get("codec_type") == "video"),
        None,
    )
    if video_stream is None:
        # Without this check the next line failed with an opaque
        # "'NoneType' object is not subscriptable".
        raise ValueError("No video stream found in {!r}".format(uri))

    width = int(video_stream["width"])
    height = int(video_stream["height"])

    # Not every container stores a per-stream duration / frame count
    # (NOTE(review): Matroska/WebM commonly omit them). Fall back to the
    # container-level duration and to frame rate x duration.
    raw_duration = video_stream.get("duration") or probe.get("format", {}).get("duration")
    if raw_duration is None:
        raise ValueError("Cannot determine the duration of {!r}".format(uri))
    duration = float(raw_duration)

    if video_stream.get("nb_frames") is not None:
        num_frames = int(video_stream["nb_frames"])
    else:
        try:
            fps = float(Fraction(video_stream.get("avg_frame_rate", "0/1")))
        except (ValueError, ZeroDivisionError):  # e.g. "0/0"
            fps = 0.0
        if fps <= 0:
            raise ValueError("Cannot determine the frame count of {!r}".format(uri))
        # Estimate only; variable-frame-rate videos may differ slightly.
        num_frames = int(round(fps * duration))

    return (width, height, num_frames, duration)
@F.udf(returnType=ArrayType(BinaryType()))
def video2images(uri, width, height,
                 sample_rate: int = 5,
                 start: float = 0.0,
                 end: float = -1.0,
                 n_channels: int = 3):
    """Decode sampled frames of a video segment into raw RGB byte strings.

    FFmpeg seeks to ``start``, stops after ``end - start`` seconds when ``end``
    is non-negative, and resamples the output to ``1 / sample_rate`` frames
    per second (one frame every ``sample_rate`` seconds).

    Parameters
    ----------
    uri : str
        Video path or URI passed to ``ffmpeg.input``.
    width, height : int
        Frame size in pixels (e.g. from ``video_probe``).
    sample_rate : int
        Seconds between sampled frames.
    start : float
        Segment start in seconds.
    end : float
        Segment end in seconds. A negative value (the default) means
        "until the end of the video".
    n_channels : int
        Bytes per pixel. Must be 3 to match the ``rgb24`` pixel format
        requested from FFmpeg.

    Returns
    -------
    list of bytes
        One ``height * width * n_channels``-byte entry per complete frame.
    """
    output_kwargs = {
        "format": "rawvideo",
        "pix_fmt": "rgb24",
        "ss": start,
        "r": 1 / sample_rate,
    }
    # A negative `end` means "to the end of the stream". The old code always
    # passed t = end - start, i.e. a negative duration for the default.
    if end is not None and end >= 0:
        output_kwargs["t"] = end - start

    video_data, _ = (
        ffmpeg.input(uri, threads=1)
        .output("pipe:", **output_kwargs)
        .run(capture_stdout=True)
    )
    img_size = height * width * n_channels
    # Emit whole frames only; a truncated trailing chunk would fail to
    # reshape in the downstream UDFs.
    n_frames = len(video_data) // img_size
    return [video_data[i * img_size:(i + 1) * img_size] for i in range(n_frames)]
# Attach the ffprobe metadata struct to every video row.
df = df.withColumn("metadata", video_probe(F.col("video_uri")))
print("With Metadata")
df.show(10, truncate=False)

# One row per sampled frame: every 1 s between 0 s and 5 s of each video.
df = df.withColumn(
    "frame",
    F.explode(
        video2images(
            F.col("video_uri"),
            F.col("metadata.width"),
            F.col("metadata.height"),
            F.lit(1),    # sample_rate (seconds between frames)
            F.lit(0.0),  # start (s)
            F.lit(5.0),  # end (s)
        )
    ),
)
import face_recognition | |
# Axis-aligned bounding box: four non-nullable integer pixel coordinates.
box_struct = StructType(
    [StructField(coord, IntegerType(), False) for coord in ("xmin", "ymin", "xmax", "ymax")]
)
def bbox_helper(bbox):
    """Convert a ``face_recognition`` box into ``box_struct`` field order.

    ``face_recognition.face_locations`` yields ``(top, right, bottom, left)``,
    while ``box_struct`` is ``(xmin, ymin, xmax, ymax)``.

    Parameters
    ----------
    bbox : sequence of 4 ints
        ``(top, right, bottom, left)``.

    Returns
    -------
    list of int
        ``[left, top, right, bottom]``, with negative coordinates clamped to 0.
    """
    top, right, bottom, left = bbox
    # x before y. The previous order [top, left, bottom, right] stored row (y)
    # coordinates in the x fields, so downstream code treated the box as
    # transposed (e.g. clamping the right edge against the frame height).
    return [max(coord, 0) for coord in (left, top, right, bottom)]
@F.udf(returnType=ArrayType(box_struct))
def face_detector(img_data, width=1920, height=1080, n_channels=3):
    """Run face detection on one raw frame.

    Parameters
    ----------
    img_data : bytes
        Raw ``height * width * n_channels`` uint8 pixels (as produced by
        ``video2images``).
    width, height, n_channels : int
        Frame geometry used to reshape ``img_data``.

    Returns
    -------
    list
        One ``box_struct``-ordered box per detected face, converted by
        ``bbox_helper``.
    """
    pixels = np.frombuffer(img_data, np.uint8)
    frame_array = pixels.reshape(height, width, n_channels)
    return list(map(bbox_helper, face_recognition.face_locations(frame_array)))
# Detect faces in every sampled frame; frame geometry comes from the probe metadata.
df = df.withColumn(
    "faces",
    face_detector(
        F.col("frame"),
        F.col("metadata.width"),
        F.col("metadata.height"),
    ),
)
# Per-frame list of tracked detections: a bounding box plus its tracker id.
annot_schema = ArrayType(
    StructType(
        [
            StructField(field_name, field_type, False)
            for field_name, field_type in (
                ("bbox", box_struct),
                ("tracker_id", StringType()),
            )
        ]
    )
)
def bbox_iou(b1, b2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Parameters
    ----------
    b1, b2 : sequence of 4 numbers
        Boxes as ``(x0, y0, x1, y1)`` with ``x0 <= x1`` and ``y0 <= y1``
        (e.g. ``box_struct`` rows). The result is unchanged if both boxes
        consistently swap their x and y roles.

    Returns
    -------
    float
        IoU in ``[0, 1]``. Returns ``0.0`` when the boxes do not overlap, or
        when the union has zero area (which previously produced NaN and broke
        the assignment solver).
    """
    left, top = max(b1[0], b2[0]), max(b1[1], b2[1])
    right, bottom = min(b1[2], b2[2]), min(b1[3], b2[3])
    if right < left or bottom < top:
        return 0.0

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    inter_area = (right - left) * (bottom - top)
    union_area = area(b1) + area(b2) - inter_area
    if union_area <= 0:
        return 0.0
    return inter_area / float(union_area)
@F.udf(returnType=MapType(IntegerType(), IntegerType()))
def tracker_match(trackers, detections, bbox_col="bbox", threshold=0.3):
    """Match bounding boxes across two successive image frames.

    Builds an IoU matrix with ``detections`` as rows and ``trackers`` as
    columns. It then solves the optimal one-to-one assignment with
    ``scipy.optimize.linear_sum_assignment`` and keeps only pairs whose IoU
    reaches ``threshold``.

    Parameters
    ----------
    trackers : list of structs with a ``bbox_col`` field
        Objects of one frame. ``track_detections`` passes the current frame's
        ``trackables`` here.
    detections : list of structs with a ``bbox_col`` field
        Objects of the other frame. ``track_detections`` passes the previous
        frame's ``old_trackables`` here.
    bbox_col : str
        Name of the bounding-box field inside each struct.
    threshold : float
        Minimum IoU for a pair to count as a match.

    Returns
    -------
    dict[int, int]
        ``{detection_index: tracker_index}`` for every accepted pair. Returns
        ``{}`` when either side is empty/None or nothing reaches the threshold.
    """
    from scipy.optimize import linear_sum_assignment

    if not trackers or not detections:
        return {}

    # float64 so threshold comparisons are not perturbed by float32 rounding.
    sim_mat = np.array(
        [
            [bbox_iou(tracked[bbox_col], detection[bbox_col]) for tracked in trackers]
            for detection in detections
        ],
        dtype=np.float64,
    )
    # Maximise total IoU. The solver returns two parallel index arrays
    # (rows, cols), which must be zipped into pairs. The old code iterated the
    # arrays themselves and hid the resulting errors with a bare `except`.
    row_ind, col_ind = linear_sum_assignment(-sim_mat)
    return {
        int(row): int(col)
        for row, col in zip(row_ind, col_ind)
        if sim_mat[row, col] >= threshold
    }
@F.udf(returnType=ArrayType(box_struct))
def OFMotionModel(frame, prev_frame, bboxes, height, width):
    """Shift bounding boxes by the mean dense optical flow inside each box.

    Computes DIS optical flow from ``prev_frame`` to ``frame`` and moves each
    box by the average flow vector over its area, clamped to the frame.

    Parameters
    ----------
    frame, prev_frame : bytes
        Raw ``height x width x 3`` uint8 frames (``video2images`` decodes them
        as rgb24). When ``prev_frame`` is empty/None, the frame is compared with
        itself, so boxes should stay essentially in place.
    bboxes : list of box_struct
        Boxes as ``(xmin, ymin, xmax, ymax)``.
    height, width : int
        Frame geometry used to reshape the byte buffers.

    Returns
    -------
    list of dict
        Shifted boxes with ``xmin``, ``ymin``, ``xmax``, ``ymax`` keys.
    """
    if not bboxes:
        return []
    if not prev_frame:
        prev_frame = frame

    def to_gray(buf):
        rgb = np.frombuffer(buf, np.uint8).reshape(height, width, 3)
        # Frames are rgb24, so convert from RGB; BGR2GRAY would weight the
        # red and blue channels the wrong way round.
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    gray = to_gray(frame)
    prev_gray = to_gray(prev_frame)
    inst = cv2.DISOpticalFlow.create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM)
    inst.setUseSpatialPropagation(False)
    flow = inst.calc(prev_gray, gray, None)
    h, w = flow.shape[:2]

    shifted_boxes = []
    for box in bboxes:
        xmin, ymin, xmax, ymax = box
        region = flow[int(ymin):int(ymax), int(xmin):int(xmax)]
        if region.size == 0:
            # Degenerate or out-of-frame box: the mean would be NaN and
            # int(NaN) raises, so leave the box unshifted.
            dx = dy = 0.0
        else:
            # OpenCV flow: channel 0 is the x (column) displacement and
            # channel 1 the y (row) displacement.
            dx = float(np.mean(region[..., 0]))
            dy = float(np.mean(region[..., 1]))
        shifted_boxes.append(
            {
                "xmin": int(max(0, xmin + dx)),
                "ymin": int(max(0, ymin + dy)),
                "xmax": int(min(w, xmax + dx)),
                "ymax": int(min(h, ymax + dy)),
            }
        )
    return shifted_boxes
def match_annotations(iterator, segment_id="video_uri", id_col="tracker_id"):
    """Propagate tracker ids frame to frame over one ``mapPartitions`` chunk.

    Each element is ``(key, (segment, trackables, matched))``:

    - ``trackables`` is the frame's list of detection structs.
    - ``matched`` maps indices of the previous frame's detections to indices
      of this frame's detections (the ``tracker_match`` output).

    Elements are assumed to arrive in frame order, with each segment's frames
    contiguous. NOTE(review): the caller must guarantee this ordering.

    Parameters
    ----------
    iterator : iterable
        Partition elements as described above.
    segment_id : str
        Kept for backward compatibility; not used.
    id_col : str
        Name of the tracker-id field written into each output Row.

    Returns
    -------
    list of list of Row
        One entry per input element, holding one Row per detection of that
        frame. Matched detections inherit the previous frame's id. Every other
        detection, including all detections of a segment's first frame, gets
        a fresh UUID string. Previously those were dropped entirely.
    """
    matched_annots = []
    prev_segment = None
    prev_ids = {}  # detection index in the previous frame -> tracker id
    for _, (segment, trackables, matched) in iterator:
        if segment != prev_segment:
            # New video: ids never carry over between segments.
            prev_ids = {}
        # Invert {previous index: current index} so each current detection
        # can look up its predecessor.
        predecessor = {cur: prev for prev, cur in (matched or {}).items()}

        curr_ids = {}
        annots = []
        for cur_idx, trackable in enumerate(trackables or []):
            detection = trackable.asDict()
            prev_idx = predecessor.get(cur_idx)
            tracker_id = prev_ids.get(prev_idx) if prev_idx is not None else None
            if tracker_id is None:
                # str() because the schema declares tracker_id as StringType.
                tracker_id = str(uuid.uuid4())
            detection[id_col] = tracker_id
            curr_ids[cur_idx] = tracker_id
            annots.append(Row(**detection))

        matched_annots.append(annots)
        prev_segment = segment
        prev_ids = curr_ids
    return matched_annots
def track_detections(df, segment_id="video_uri", frames="frame", detections="faces", optical_flow=True):
    """Assign persistent tracker ids to per-frame detections of each video.

    Steps:

    1. Optionally shift detections with dense optical flow (``OFMotionModel``).
    2. Collect each frame's detections as ``trackables`` and match them
       against the previous frame's (``tracker_match``).
    3. Walk each video's frames in order on its own partition, propagating
       ids (``match_annotations``).
    4. Join the ids back onto their own frames and hash them with the segment.

    Parameters
    ----------
    df : DataFrame
        Must contain ``segment_id``, ``frames`` and ``detections``. When
        ``optical_flow`` is true it also needs ``metadata.height`` and
        ``metadata.width``.
    segment_id : str
        Column identifying the video.
    frames : str
        Column with the raw frame bytes.
    detections : str
        Column with the per-frame array of boxes.
    optical_flow : bool
        Whether to shift detections with optical flow first.

    Returns
    -------
    DataFrame
        Columns ``segment_id``, ``frames``, ``detections`` and
        ``tracked_detections`` (array of ``{bbox, tracker_id}``). There is one
        row per frame that has at least one detection (the ``explode`` drops
        empty frames), ordered by ``segment_id, frames, detections``.
    """
    id_col = "tracker_id"
    # NOTE(review): there is no frame-number column, so frames are ordered by
    # their raw pixel bytes, not by time. "Previous frame" below therefore
    # means previous in byte order. Carry a frame position (e.g. via
    # F.posexplode upstream) and order by it for temporal tracking.
    frame_window = Window().orderBy(frames)
    annot_window = Window.partitionBy(segment_id).orderBy(segment_id, frames)
    indexer = StringIndexer(inputCol=segment_id, outputCol="vidIndex")

    # Adjust detections with optical flow.
    if optical_flow:
        df = (
            df.withColumn("prev_frames", F.lag(F.col(frames)).over(annot_window))
            .withColumn(
                detections,
                OFMotionModel(
                    F.col(frames),
                    F.col("prev_frames"),
                    F.col(detections),
                    F.col("metadata.height"),
                    F.col("metadata.width"),
                ),
            )
        )

    df = (
        df.select(segment_id, frames, detections)
        .withColumn("bbox", F.explode(detections))
        .withColumn(id_col, F.lit(""))
        .withColumn("trackables", F.struct([F.col("bbox"), F.col(id_col)]))
        .groupBy(segment_id, frames, detections)
        .agg(F.collect_list("trackables").alias("trackables"))
        .withColumn("old_trackables", F.lag(F.col("trackables")).over(annot_window))
        .withColumn("matched", tracker_match(F.col("trackables"), F.col("old_trackables")))
        .withColumn("frame_index", F.row_number().over(frame_window))
    )
    # StringIndexer emits doubles (0.0, 1.0, ...). Cast to int so they can
    # serve as partition numbers. The old string cast plus int(x[0]) read only
    # the first character, so e.g. videos 1 and 10 collided.
    df = (
        indexer.fit(df)
        .transform(df)
        .withColumn("vidIndex", F.col("vidIndex").cast(IntegerType()))
    )
    num_videos = df.select("vidIndex").distinct().count()

    def _match_partition(rows):
        # Shuffles do not preserve order: restore frame order within the video.
        rows = sorted(rows, key=lambda row: row[0])
        annotations = match_annotations(iter(rows))
        # Keep each frame's index next to its annotations so the join below
        # attaches them to the frame they were computed for.
        return [(key[1], annots) for (key, _), annots in zip(rows, annotations)]

    matched = (
        df.select("vidIndex", "frame_index", segment_id, "trackables", "matched")
        .rdd.map(lambda r: ((r[0], r[1]), tuple(r[2:])))
        .partitionBy(max(num_videos, 1), lambda key: key[0])
        .mapPartitions(_match_partition)
    )
    matched_schema = StructType(
        [
            StructField("frame_index", IntegerType(), False),
            StructField("trackers_matched", annot_schema, False),
        ]
    )
    matched_annotations = sqlContext.createDataFrame(matched, matched_schema)

    return (
        df.join(matched_annotations, on="frame_index")
        .withColumn("tracked", F.explode(F.col("trackers_matched")))
        .select(
            segment_id,
            frames,
            detections,
            F.col("tracked.bbox").alias("bbox"),
            F.col("tracked.{}".format(id_col)).alias(id_col),
        )
        .withColumn(id_col, F.sha2(F.concat(F.col(segment_id), F.col(id_col)), 256))
        .withColumn("tracked_detections", F.struct([F.col("bbox"), F.col(id_col)]))
        .groupBy(segment_id, frames, detections)
        .agg(F.collect_list("tracked_detections").alias("tracked_detections"))
        .orderBy(segment_id, frames, detections)
    )
from pyspark import keyword_only | |
from pyspark.ml.pipeline import Transformer | |
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param | |
class DetectionTracker(Transformer, HasInputCol, HasOutputCol):
    """Spark ML ``Transformer`` that assigns tracker ids to per-frame detections.

    Params
    ------
    inputCol : str
        Column identifying the video/segment (e.g. ``"video_uri"``).
    outputCol : str
        Name of the produced array of ``{bbox, tracker_id}`` structs.
    framesCol : str
        Column with raw frame bytes (default ``"frame"``).
    detectionsCol : str
        Column with per-frame detections (default ``"faces"``).
    optical_flow : bool
        Shift detections with optical flow before matching (default
        ``False``). This reads ``metadata.height`` / ``metadata.width``.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, framesCol=None, detectionsCol=None, optical_flow=None):
        """Declare the custom Params, set their defaults, then apply kwargs."""
        super(DetectionTracker, self).__init__()
        self.framesCol = Param(self, "framesCol", "Column containing frames.")
        self.detectionsCol = Param(self, "detectionsCol", "Column containing detections.")
        self.optical_flow = Param(self, "optical_flow", "Use optical flow for tracker correction. Default is False")
        self._setDefault(framesCol="frame", detectionsCol="faces", optical_flow=False)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, framesCol=None, detectionsCol=None, optical_flow=None):
        """Set any explicitly passed Params and return ``self``."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setFramesCol(self, value):
        """Set framesCol."""
        return self._set(framesCol=value)

    def getFramesCol(self):
        """Get framesCol."""
        return self.getOrDefault(self.framesCol)

    def setDetectionsCol(self, value):
        """Set detectionsCol."""
        return self._set(detectionsCol=value)

    def getDetectionsCol(self):
        """Get detectionsCol."""
        return self.getOrDefault(self.detectionsCol)

    def setOpticalflow(self, value):
        """Set optical_flow."""
        return self._set(optical_flow=value)

    def getOpticalflow(self):
        """Get optical_flow."""
        return self.getOrDefault(self.optical_flow)

    def _transform(self, dataframe):
        """Track detections via ``track_detections`` and rename its output.

        Delegating keeps a single implementation of the tracking pipeline.
        It also supplies the frame height/width that ``OFMotionModel``
        requires; the duplicated copy called it with only three arguments, so
        ``optical_flow=True`` failed at runtime.
        """
        tracked = track_detections(
            dataframe,
            segment_id=self.getInputCol(),
            frames=self.getFramesCol(),
            detections=self.getDetectionsCol(),
            optical_flow=self.getOpticalflow(),
        )
        return tracked.withColumnRenamed("tracked_detections", self.getOutputCol())
# Build the ML-style transformer (optical flow off by default).
detectTracker = DetectionTracker(inputCol="video_uri", outputCol="tracked_detections")
print(type(detectTracker))
# NOTE(review): the returned DataFrame is discarded; this call only runs the
# transformer's eager steps. The displayed result comes from track_detections
# below, whose default enables optical flow.
detectTracker.transform(df)

# Functional API with defaults; show the tracked detections.
final = track_detections(df)
print("Final dataframe")
final.select("tracked_detections").show(100, truncate=False)
Output:
[<Row('../data/video_files/faces/production_id_3761466 (2160p).mp4')>] | |
Initial dataframe | |
+-----------------------------------------------------------+ | |
|video_uri | | |
+-----------------------------------------------------------+ | |
|../data/video_files/faces/production_id_3761466 (2160p).mp4| | |
+-----------------------------------------------------------+ | |
With Metadata | |
+-----------------------------------------------------------+------------------------+ | |
|video_uri |metadata | | |
+-----------------------------------------------------------+------------------------+ | |
|../data/video_files/faces/production_id_3761466 (2160p).mp4|{3840, 2160, 288, 11.52}| | |
+-----------------------------------------------------------+------------------------+ | |
<class '__main__.DetectionTracker'> | |
+---------------------------------------------------------------------------------------------+ | |
|tracked_detections | | |
+---------------------------------------------------------------------------------------------+ | |
|[{{649, 1810, 1204, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]| | |
|[{{678, 1777, 1233, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]| | |
|[{{725, 1774, 1280, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]| | |
|[{{728, 1760, 1283, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]| | |
+---------------------------------------------------------------------------------------------+ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment