Video frame sequence DataLoader in PyTorch.
import os
import os.path
import random
from math import ceil
from typing import Dict, Optional, Tuple

import cv2 as cv
import numpy as np
import pafy
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
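# The gist references hms_to_sec() and a RANDOMIZE flag without defining them;
# the definitions below are assumed reconstructions, not part of the original.
RANDOMIZE = True  # randomly jitter the starting frame of each sample

def hms_to_sec(hms: str) -> float:
    """Converts an 'HH:MM:SS' (or 'MM:SS') timestamp string to seconds."""
    sec = 0.0
    for part in hms.split(':'):
        sec = sec * 60 + float(part)
    return sec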
class VideoFrameDataset(data.Dataset):
    """
    Custom Dataset class for on-demand loading of sequences of frames from given videos.
    """
    def __init__(self, video_urls: Dict[int, str], video_ivals: Dict[int, Tuple[str, str]],
                 data_dir: str, download: bool = True, train_split: Optional[int] = None,
                 train: bool = True, transform: Optional[transforms.Compose] = None,
                 sample_step: int = 2, skip_step: int = 3, n_frames: int = 3):
""" | |
param video_urls: dictinary of the dataset video urls. | |
param video_ivals: dictionary of valid intervals per video. | |
param data_dir: local directory to store downloaded video files. | |
param download: specifies whether to download the videos to local dir or not. | |
param train_split: number of samples to be used for training. | |
param train: train or test dataset. | |
param transform: transformation to be applied to each frame. | |
param sample_step: step between frames at each sample. | |
param skip_step: number of frames to skip between samples. | |
param n_frames: number of input/target frames in each sample. | |
""" | |
        self.n_frames = n_frames
        self.sample_step = sample_step
        self.skip_step = skip_step
        self.transform = transform
        self.video_caps = dict()
        self.video_sample_sizes = dict()
        self.video_start_frames = dict()
        self.video_end_frames = dict()
        self.dataset_size = 0
        prefix = 'train' if train else 'test'
        clip_length = n_frames*2  # total frames (inputs + targets) per sample
        for i, url in video_urls.items():
            video = pafy.new(url)
            video_stream = video.getbest(preftype='mp4')
            if download:
                fname = os.path.join(data_dir, f'{prefix}_{i}')
                if not os.path.isfile(fname):
                    video_stream.download(filepath=fname, quiet=True)
            else:
                fname = video_stream.url
            video_cap = cv.VideoCapture(fname)
            self.video_caps[i] = video_cap
            # Convert the valid interval from timestamps to frame numbers.
            fps = video_cap.get(cv.CAP_PROP_FPS)
            start_sec = hms_to_sec(video_ivals[i][0])
            start_frame = ceil(fps*start_sec)
            end_sec = hms_to_sec(video_ivals[i][1])
            end_frame = ceil(fps*end_sec)
            self.video_start_frames[i] = start_frame
            self.video_end_frames[i] = end_frame
            video_length = end_frame - start_frame + 1
            # Samples start every sample_step*skip_step frames within the interval.
            self.video_sample_sizes[i] = ceil((video_length - 1) / (self.sample_step*self.skip_step))
            self.dataset_size += self.video_sample_sizes[i]
        if train_split is not None:
            assert train_split <= self.dataset_size, 'Invalid train_split argument!'
            self.dataset_size = train_split
    def __len__(self) -> int:
        """
        Returns the dataset length.
        """
        return self.dataset_size
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of input and target frame sequences at the given sample index.
        """
        assert index < self.dataset_size, 'Invalid index!'
        vnum, fnum = self._get_video_frame_num(index)
        video_cap = self.video_caps[vnum]
        # Optionally jitter the starting frame within the sample stride.
        fnum_offset = random.randint(0, self.sample_step*self.skip_step - 1) if RANDOMIZE else 0
        if fnum + fnum_offset < self.video_end_frames[vnum]:
            fnum = fnum + fnum_offset
        video_cap.set(cv.CAP_PROP_POS_FRAMES, fnum)
        frame_seq = []
        for i in range(2*self.n_frames*self.sample_step):
            ret = video_cap.grab()  # grab the next frame without decoding it
            assert ret
            if i % self.sample_step != 0:
                continue
            ret, frame = video_cap.retrieve()  # decode only the frames we keep
            assert ret
            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)  # equivalently: np.flip(frame, axis=2)
            frame = Image.fromarray(frame)
            if self.transform is not None:
                frame = self.transform(frame)
            else:
                frame = transforms.ToTensor()(frame)
            frame_seq.append(frame)
        frame_seq = torch.stack(frame_seq, dim=0)
        input_seq = frame_seq[:self.n_frames]
        target_seq = frame_seq[self.n_frames:]
        return input_seq, target_seq
    def _get_video_frame_num(self, index: int) -> Tuple[int, int]:
        """
        Returns the video number and corresponding frame number within the video,
        according to the given index.
        """
        N = 0
        vnum = 0
        while index >= N:
            N += self.video_sample_sizes[vnum]
            vnum += 1
        vnum -= 1
        sample_num = index - (N - self.video_sample_sizes[vnum])
        fnum = self.video_start_frames[vnum] + sample_num * self.sample_step * self.skip_step
        return vnum, fnum
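A minimal usage sketch (the URL, interval timestamps, and data directory below are hypothetical placeholders):

video_urls = {0: 'https://www.youtube.com/watch?v=...'}  # hypothetical URL
video_ivals = {0: ('00:01:30', '00:10:00')}
dataset = VideoFrameDataset(video_urls, video_ivals, data_dir='./videos',
                            download=True, train=True,
                            sample_step=2, skip_step=3, n_frames=3)
loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
inputs, targets = next(iter(loader))  # each of shape (4, n_frames, C, H, W)

Note that cv.VideoCapture handles generally do not survive pickling, so num_workers=0 is the safe choice with this dataset.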