Skip to content

Instantly share code, notes, and snippets.

@kiyoon
Last active February 11, 2020 17:26
Show Gist options
  • Save kiyoon/ae84ee3736c1350b20901bfb4a60d621 to your computer and use it in GitHub Desktop.
Save kiyoon/ae84ee3736c1350b20901bfb4a60d621 to your computer and use it in GitHub Desktop.
PyTorch video loader utilising GPU (CUDA) using NVIDIA DALI > 0.18.
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin import pytorch
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import argparse
# Command-line options. They are parsed at module level so that ``args`` is
# available to the ``__main__`` section at the bottom of the file.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--file_list',
    default='file_list.txt',
    type=str,
    help='DALI file_list for VideoReader',
)
parser.add_argument(
    '--frames',
    default=3,
    type=int,
    help='num frames in input sequence',
)
parser.add_argument(
    '--crop_size',
    default=[224, 224],
    type=int,
    nargs='+',
    help='[height, width] for input crop',
)
parser.add_argument(
    '--batchsize',
    default=1,
    type=int,
    help='per rank batch size',
)
args = parser.parse_args()
class VideoReaderPipeline(Pipeline):
    """DALI pipeline that reads video sequences on the GPU and augments them.

    Every sequence read from ``file_list`` is cropped at a random position,
    flipped horizontally with probability 0.5 and finally transposed with
    ``perm=[3, 0, 1, 2]``.

    Args:
        batch_size: Number of sequences per batch.
        sequence_length: Number of frames in each sequence.
        num_threads: Number of threads handed to the base ``Pipeline``.
        device_id: GPU id handed to the base ``Pipeline``.
        file_list: Path to the DALI ``file_list`` consumed by ``VideoReader``.
        crop_size: Crop size passed to ``ops.Crop`` (``[height, width]``
            according to the command-line help of this script).
    """

    def __init__(self, batch_size, sequence_length, num_threads, device_id, file_list, crop_size):
        super(VideoReaderPipeline, self).__init__(batch_size, num_threads, device_id, seed=12)
        # enable_frame_num=True makes the reader emit a third output
        # (starting frame number) in addition to frames and label.
        self.reader = ops.VideoReader(device="gpu", file_list=file_list, sequence_length=sequence_length,
                                      normalized=False, random_shuffle=True, image_type=types.RGB,
                                      dtype=types.UINT8, initial_fill=16, enable_frame_num=True)
        self.crop = ops.Crop(device="gpu", crop=crop_size, output_dtype=types.FLOAT)
        self.uniform = ops.Uniform(range=(0.0, 1.0))
        self.coin = ops.CoinFlip(probability=0.5)
        # define_graph() uses self.flip; without this operator the graph
        # construction fails with AttributeError.
        self.flip = ops.Flip(device="gpu")
        # Presumably converts (frames, H, W, C) into (C, frames, H, W) —
        # confirm against the VideoReader output layout.
        self.transpose = ops.Transpose(device="gpu", perm=[3, 0, 1, 2])

    def define_graph(self):
        """Wire the operators together and return the pipeline outputs.

        Returns:
            A 6-tuple: the augmented frames, the label, the starting frame
            number (indexed from zero), the crop x position, the crop y
            position and the horizontal-flip flag.
        """
        # reader_out[0]: frames, reader_out[1]: label,
        # reader_out[2]: starting frame number indexed from zero
        reader_out = self.reader(name="Reader")

        crop_pos_x = self.uniform()
        crop_pos_y = self.uniform()
        cropped = self.crop(reader_out[0], crop_pos_x=crop_pos_x, crop_pos_y=crop_pos_y)

        is_flipped = self.coin()
        flipped = self.flip(cropped, horizontal=is_flipped)

        output = self.transpose(flipped)
        # Change what you want from the dataloader.
        return output, reader_out[1], reader_out[2], crop_pos_x, crop_pos_y, is_flipped
class DALILoader():
    """Iterable wrapper that builds a ``VideoReaderPipeline`` and exposes it
    through a DALI PyTorch iterator.

    Each item produced by iteration is what ``DALIGenericIterator`` yields:
    a list whose first element is a dict keyed by the names in
    ``OUTPUT_MAP``.

    Args:
        batch_size: Number of sequences per batch.
        file_list: Path to the DALI ``file_list`` for the video reader.
        sequence_length: Number of frames per sequence.
        crop_size: Crop size forwarded to the pipeline.
        num_threads: Worker threads for the pipeline (default 2).
        device_id: GPU id for the pipeline (default 0).
    """

    # One name per output of VideoReaderPipeline.define_graph(), in order.
    # The pipeline returns six outputs, so "is_flipped" must be listed too;
    # otherwise the last output has no name to be stored under.
    OUTPUT_MAP = ("data", "label", "frame_num", "crop_pos_x", "crop_pos_y", "is_flipped")

    def __init__(self, batch_size, file_list, sequence_length, crop_size, num_threads=2, device_id=0):
        self.pipeline = VideoReaderPipeline(batch_size=batch_size,
                                            sequence_length=sequence_length,
                                            num_threads=num_threads,
                                            device_id=device_id,
                                            file_list=file_list,
                                            crop_size=crop_size)
        self.pipeline.build()
        # "Reader" is the name given to the reader op in define_graph().
        self.epoch_size = self.pipeline.epoch_size("Reader")
        self.dali_iterator = pytorch.DALIGenericIterator(self.pipeline,
                                                         list(self.OUTPUT_MAP),
                                                         self.epoch_size,
                                                         auto_reset=True)

    def __len__(self):
        """Return the epoch size reported by the pipeline's reader."""
        return int(self.epoch_size)

    def __iter__(self):
        """Return an iterator over the underlying DALI iterator."""
        return iter(self.dali_iterator)

    def __next__(self):
        """Return the next batch from the underlying DALI iterator."""
        return next(self.dali_iterator)
if __name__ == "__main__":
    # Smoke test: build the loader from the command-line options, pull one
    # batch and print what came out of it.
    loader = DALILoader(
        batch_size=args.batchsize,
        file_list=args.file_list,
        sequence_length=args.frames,
        crop_size=args.crop_size,
    )
    batches = len(loader)

    batch = next(loader)
    first = batch[0]  # the iterator yields a list; take its first dict
    print(first['data'].shape)
    for key in ('label', 'frame_num', 'crop_pos_x', 'crop_pos_y'):
        print(first[key])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment