Skip to content

Instantly share code, notes, and snippets.

@xkumiyu
Last active September 7, 2017 16:52
Show Gist options
  • Save xkumiyu/9baf81ac40ee43f87bce80b991c969d2 to your computer and use it in GitHub Desktop.
Save xkumiyu/9baf81ac40ee43f87bce80b991c969d2 to your computer and use it in GitHub Desktop.
Chainer Dataset of Videos
import os
import numpy as np
import six
import cv2
import chainer
def _read_video_as_array(path, dtype):
    """Decode every frame of a video file and stack them into one array.

    Args:
        path (str): Path of the video file, passed to ``cv2.VideoCapture``.
        dtype: Data type of the returned array.

    Returns:
        numpy.ndarray: The decoded frames stacked along a new leading axis,
        i.e. shape ``(frame,) + frame_shape``.  Frames are kept exactly as
        OpenCV decodes them (no colour conversion is done here).  If the file
        cannot be opened or yields no frames, the result is an empty array of
        shape ``(0,)`` -- callers must check for that case.
    """
    frames = []
    cap = cv2.VideoCapture(path)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            # Stop at the first frame that cannot be read (end of stream or a
            # decode failure).
            if not ok:
                break
            frames.append(frame)
    finally:
        # Release the capture handle even if decoding raises part-way through.
        cap.release()
    return np.asarray(frames, dtype=dtype)
class VideoDataset(chainer.dataset.DatasetMixin):
    """Dataset of video files, each decoded into a 4-D array when accessed.

    Args:
        paths (str or list of str): Either a sequence of video file paths, or
            the path of a text file that lists one video path per line.
            Surrounding whitespace is stripped and blank lines are skipped.
        root (str): Directory prepended to every entry of ``paths``.
        dtype: Data type of the arrays returned by :meth:`get_example`.
    """

    def __init__(self, paths, root='.', dtype=np.float32):
        if isinstance(paths, six.string_types):
            with open(paths) as paths_file:
                # Blank lines would otherwise become empty entries that join
                # to ``root`` itself and fail later with an unclear error.
                paths = [line.strip() for line in paths_file if line.strip()]
        self._paths = paths
        self._root = root
        self._dtype = dtype

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self._paths)

    def get_example(self, i):
        """Return the i-th video.

        return video shape: (ch, frame, width, height)

        Raises:
            ValueError: If the file could not be decoded into 4-D frame data
                (e.g. it does not exist or contains no readable frames).
        """
        path = os.path.join(self._root, self._paths[i])
        video = _read_video_as_array(path, self._dtype)
        # The reader returns a 1-D empty array when nothing could be decoded;
        # fail here with a clear message instead of inside transpose().
        if video.ndim != 4:
            raise ValueError(
                'could not read video frames from: {}'.format(path))
        # Assumes decoded frames are (height, width, ch), as OpenCV returns
        # them, so the stack is (frame, height, width, ch); reorder it to
        # (ch, frame, width, height).
        return video.transpose(3, 0, 2, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment