@AniKar
Last active March 14, 2022
Video frame sequence DataLoader in PyTorch.
import os
import os.path
import random
from math import ceil
from typing import Dict, Optional, Tuple

import cv2 as cv
import numpy as np
import pafy
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
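# The gist calls `hms_to_sec` and reads a global `RANDOMIZE` flag without
# defining either; the sketches below are assumptions added so the snippet
# runs, matching how both are used further down.

RANDOMIZE = True  # assumed flag: jitter the start frame of each sample


def hms_to_sec(hms: str) -> float:
    """Convert an 'HH:MM:SS' (or 'MM:SS') timestamp string to seconds."""
    sec = 0.0
    for part in hms.split(':'):
        sec = sec * 60 + float(part)
    return sec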
class VideoFrameDataset(data.Dataset):
    """
    Custom Dataset class for on-demand loading of sequences of frames from given videos.
    """

    def __init__(self, video_urls: Dict[int, str], video_ivals: Dict[int, Tuple[str, str]],
                 data_dir: str, download: bool = True, train_split: Optional[int] = None,
                 train: bool = True, transform: Optional[transforms.Compose] = None,
                 sample_step: int = 2, skip_step: int = 3, n_frames: int = 3):
"""
param video_urls: dictinary of the dataset video urls.
param video_ivals: dictionary of valid intervals per video.
param data_dir: local directory to store downloaded video files.
param download: specifies whether to download the videos to local dir or not.
param train_split: number of samples to be used for training.
param train: train or test dataset.
param transform: transformation to be applied to each frame.
param sample_step: step between frames at each sample.
param skip_step: number of frames to skip between samples.
param n_frames: number of input/target frames in each sample.
"""
        self.n_frames = n_frames
        self.sample_step = sample_step
        self.skip_step = skip_step
        self.transform = transform
        self.video_caps = dict()
        self.video_sample_sizes = dict()
        self.video_start_frames = dict()
        self.video_end_frames = dict()
        self.dataset_size = 0
        prefix = 'train' if train else 'test'
        for i, url in video_urls.items():
            video = pafy.new(url)
            video_stream = video.getbest(preftype='mp4')
            if download:
                fname = os.path.join(data_dir, f'{prefix}_{i}')
                if not os.path.isfile(fname):
                    video_stream.download(filepath=fname, quiet=True)
            else:
                # Stream directly from the remote URL instead of a local file.
                fname = video_stream.url
            video_cap = cv.VideoCapture(fname)
            self.video_caps[i] = video_cap
            # Convert the valid interval from timestamps to frame numbers.
            fps = video_cap.get(cv.CAP_PROP_FPS)
            start_frame = ceil(fps * hms_to_sec(video_ivals[i][0]))
            end_frame = ceil(fps * hms_to_sec(video_ivals[i][1]))
            self.video_start_frames[i] = start_frame
            self.video_end_frames[i] = end_frame
            video_length = end_frame - start_frame + 1
            # One sample every sample_step * skip_step frames.
            self.video_sample_sizes[i] = (video_length - 1) // (self.sample_step * self.skip_step)
            self.dataset_size += self.video_sample_sizes[i]
        if train_split is not None:
            assert train_split <= self.dataset_size, 'Invalid train_split argument!'
            self.dataset_size = train_split
    def __len__(self) -> int:
        """
        Returns the dataset length.
        """
        return self.dataset_size
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of input and target frame sequences at the given sample index.
        """
        assert index < self.dataset_size, 'Invalid index!'
        vnum, fnum = self._get_video_frame_num(index)
        video_cap = self.video_caps[vnum]
        # Optionally jitter the start frame within the gap between samples.
        fnum_offset = random.randint(0, self.sample_step * self.skip_step - 1) if RANDOMIZE else 0
        if fnum + fnum_offset < self.video_end_frames[vnum]:
            fnum = fnum + fnum_offset
        video_cap.set(cv.CAP_PROP_POS_FRAMES, fnum)
        frame_seq = []
        for i in range(2 * self.n_frames * self.sample_step):
            ret = video_cap.grab()  # grab the next frame without decoding it
            assert ret
            if i % self.sample_step != 0:
                continue
            ret, frame = video_cap.retrieve()  # decode only the frames we keep
            assert ret
            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)  # np.flip(frame, axis=2)
            frame = Image.fromarray(frame)
            if self.transform is not None:
                frame = self.transform(frame)
            else:
                frame = transforms.ToTensor()(frame)
            frame_seq.append(frame)
        # Shape (2 * n_frames, C, H, W): first half is the input, second half the target.
        frame_seq = torch.stack(frame_seq, dim=0)
        input_seq = frame_seq[:self.n_frames]
        target_seq = frame_seq[self.n_frames:]
        return input_seq, target_seq
    def _get_video_frame_num(self, index: int) -> Tuple[int, int]:
        """
        Returns the video number and the corresponding frame number within
        that video for the given global sample index.
        """
        # Walk the per-video sample counts until the cumulative total passes index.
        N = 0
        vnum = 0
        while index >= N:
            N += self.video_sample_sizes[vnum]
            vnum += 1
        vnum -= 1
        # Sample index within the selected video, then its first frame number.
        sample_num = index - (N - self.video_sample_sizes[vnum])
        fnum = self.video_start_frames[vnum] + sample_num * self.sample_step * self.skip_step
        return vnum, fnum
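
# A minimal usage sketch (not part of the gist): the URL, interval, and
# directory below are placeholders. Wrap the dataset in a standard
# torch.utils.data.DataLoader to get batched frame sequences.
if __name__ == '__main__':
    video_urls = {0: 'https://www.youtube.com/watch?v=XXXXXXXXXXX'}  # hypothetical URL
    video_ivals = {0: ('00:00:10', '00:05:00')}  # valid (start, end) interval per video
    dataset = VideoFrameDataset(video_urls, video_ivals, data_dir='./videos')
    loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
    inputs, targets = next(iter(loader))
    print(inputs.shape, targets.shape)  # (4, n_frames, C, H, W) each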