Skip to content

Instantly share code, notes, and snippets.

@ericjang
Last active June 28, 2021 00:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericjang/39b95c964b97541f121905f751f64ee8 to your computer and use it in GitHub Desktop.
Save ericjang/39b95c964b97541f121905f751f64ee8 to your computer and use it in GitHub Desktop.
Load the image at time t and the image at time t+3 — an alternative to using two augmented views of the same image.
"""
To generate the dataset, create a folder with each video's frames extracted into its own subdirectory. The DINO codebase has code that makes this easy, e.g.:
videos = glob.glob(os.path.join(SOURCE_DIR,"*.mp4"))[:num_videos_to_extract]
for i, video in enumerate(videos):
    print(i)
    directory = os.path.join(OUTPUT_DIR, str(i))
    if not os.path.exists(directory):
        os.makedirs(directory)
    _extract_frames_from_video(inp=video, out=directory)
License: WTFPL
"""
import torch
from torchvision import datasets, transforms
from PIL import Image
import utils
from main_dino import DataAugmentationDINO
class PairedVideoFramesFolder(datasets.ImageFolder):
    """ImageFolder that pairs each frame with the frame ``time_shift`` steps later.

    Expected layout: one sub-directory per video under the dataset root, each
    holding that video's extracted frames. Each sub-directory becomes one class
    index, so a change of class index between two samples marks a video boundary.

    When the shifted frame would fall past the end of the dataset, or belongs to
    a different video (i.e. frame t is among the last ``time_shift`` frames of
    its video), frame t is paired with itself.

    NOTE(review): pairing assumes ``self.samples`` lists each video's frames
    contiguously and in temporal order. That depends on how the parent class
    sorts file names (e.g. ``frame_10`` vs ``frame_2``). Confirm that frame
    file names are zero-padded.
    """

    def __init__(self, time_shift=3, **kwargs):
        """
        Args:
            time_shift (int): how many frames ahead the second image is taken
                from. Defaults to 3.
            **kwargs: forwarded unchanged to ``datasets.ImageFolder``.
        """
        self.time_shift = time_shift
        super().__init__(**kwargs)

    def __getitem__(self, index):
        """Overrides DatasetFolder.

        Args:
            index (int): index into ``self.samples`` of the frame at time t.

        Returns:
            tuple: ``(sample, target)``.
                ``sample`` is ``self.transform(image_t, image_tp1)``, or the raw
                ``(image_t, image_tp1)`` pair when no transform is set.
                ``target`` is the class (video) index of frame t, passed through
                ``self.target_transform`` when one is set.
        """
        path_t, target_t = self.samples[index]
        sample_t = self.loader(path_t)

        shifted = index + self.time_shift
        # Only load a second frame when the shifted index is still inside the
        # dataset AND belongs to the same video; otherwise reuse frame t rather
        # than pairing frames across two different videos.
        if shifted < len(self.samples) and self.samples[shifted][1] == target_t:
            sample_tp1 = self.loader(self.samples[shifted][0])
        else:
            sample_tp1 = sample_t

        if self.transform is not None:
            sample = self.transform(sample_t, sample_tp1)
        else:
            sample = (sample_t, sample_tp1)

        # Bug fix: the original transformed an undefined name (`target`) and
        # then returned the untransformed `target_t`.
        if self.target_transform is not None:
            target_t = self.target_transform(target_t)
        return sample, target_t
class DataAugmentationDINOVideo(DataAugmentationDINO):
    """DINO multi-crop augmentation applied to a pair of video frames.

    The returned list contains, in this order:
      * 2 global crops of ``image_tp1`` (``global_transfo1``, then
        ``global_transfo2``), the teacher inputs;
      * 2 global crops of ``image_t`` (same two transforms), student inputs;
      * ``local_crops_number`` local crops of ``image_t``, student inputs.

    NOTE(review): this produces 4 global crops rather than 2, so the training
    loop and loss presumably have to account for the extra views. Confirm
    against the training script that consumes these crops.
    """

    def __call__(self, image_t, image_tp1):
        # The call sequence is kept fixed: if the transforms are random, the
        # order in which they are invoked determines which random draws each
        # crop receives.
        views = [
            # Teacher views: the later frame.
            self.global_transfo1(image_tp1),
            self.global_transfo2(image_tp1),
            # Student global views: the current frame.
            self.global_transfo1(image_t),
            self.global_transfo2(image_t),
        ]
        # Student local views: the current frame.
        views.extend(
            self.local_transfo(image_t) for _ in range(self.local_crops_number)
        )
        return views
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment