Last active
December 7, 2022 22:41
-
-
Save metal3d/d5e93331b04c44aa52cbfe11db4afff2 to your computer and use it in GitHub Desktop.
Keras generator to create sequence image batches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 as cv | |
import os | |
import glob | |
import numpy as np | |
import random | |
from tensorflow import keras | |
import keras.preprocessing.image | |
from tensorflow.python.keras.utils.data_utils import Sequence | |
class VideoFrameGenerator(Sequence):
    """Keras ``Sequence`` that yields batches of fixed-length frame
    sequences decoded from video files organized in per-class directories.

    Each item is a tuple ``(images, labels)`` where ``images`` has shape
    ``(batch_size, nbframe, height, width, channels)`` and ``labels`` is
    one-hot encoded over the class list.
    """

    def __init__(self,
                 rescale=1 / 255.,
                 nbframe: int = 5,
                 classes: list = None,
                 batch_size: int = 16,
                 use_frame_cache: bool = False,
                 target_shape: tuple = (224, 224),
                 shuffle: bool = True,
                 transformation: keras.preprocessing.image.ImageDataGenerator = None,
                 split: float = None,
                 nb_channel: int = 3,
                 glob_pattern: str = './videos/{classname}/*.avi',
                 _validation_data: list = None):
        """Create a generator that returns batches of frame sequences from videos.

        - rescale: float, factor applied to every pixel value
        - nbframe: int, number of frames to return for each sequence
        - classes: list of class names; they must match the directory names
          substituted into ``glob_pattern`` (non-string entries are coerced
          to ``str`` so the one-hot lookup works)
        - batch_size: int, batch size for each loop
        - use_frame_cache: bool, cache decoded frames in memory (may take a
          lot of memory for large datasets)
        - target_shape: tuple, target (width, height) of the frames
        - shuffle: bool, randomize files
        - transformation: ImageDataGenerator with transformations
        - split: float in ]0, 1[, fraction of each class kept for validation
        - nb_channel: int, 1 or 3, to get grayscale or RGB images
        - glob_pattern: string, directory path with '{classname}' inside that
          will be replaced by one of the class list
        - _validation_data: already filled list of data, do not touch!

        You may use the "classes" property to retrieve the class list afterward.

        The generator has these properties initialized:
        - classes_count: number of classes that the generator manages
        - files_count: number of videos that the generator can provide
        - classes: the (sorted, stringified) class list
        - files: the full file list that the generator will use; this is
          useful if you want to remove some files that should not be used
          by the generator.
        """
        # should be only RGB or Grayscale
        assert nb_channel in (1, 3)
        # we should have classes
        assert classes is not None and len(classes) > 0
        # shape size should be 2
        assert len(target_shape) == 2
        # split factor should be a proper value
        if split is not None:
            assert 0.0 < split < 1.0

        # Work on a sorted, stringified *copy*:
        # - the original used a mutable default argument ([]) shared
        #   across instances and sorted the caller's list in place;
        # - __getitem__ derives the class from a path component (a str),
        #   so non-string labels (e.g. ints) made classes.index() raise
        #   "ValueError: '10' is not in list".
        classes = sorted(str(c) for c in classes)

        self.rescale = rescale
        self.classes = classes
        self.batch_size = batch_size
        self.nbframe = nbframe
        self.shuffle = shuffle
        self.target_shape = target_shape
        self.nb_channel = nb_channel
        self.transformation = transformation
        self.use_frame_cache = use_frame_cache

        self._random_trans = []
        self.__frame_cache = {}
        self.files = []
        self.validation = []

        if _validation_data is not None:
            # we only need to set files here
            self.files = _validation_data
        else:
            if split is not None and split > 0.0:
                for cls in classes:
                    files = glob.glob(glob_pattern.format(classname=cls))
                    nbval = int(split * len(files))
                    print("class %s, validation count: %d" % (cls, nbval))
                    # generate validation indexes
                    indexes = np.arange(len(files))
                    if shuffle:
                        np.random.shuffle(indexes)
                    # sample validation indexes, remove them from train
                    val = np.random.permutation(indexes)[:nbval]
                    indexes = np.array([i for i in indexes if i not in val])
                    # and now, make the file lists
                    self.files += [files[i] for i in indexes]
                    self.validation += [files[i] for i in val]
            else:
                for cls in classes:
                    self.files += glob.glob(glob_pattern.format(classname=cls))

        # build indexes
        self.files_count = len(self.files)
        self.indexes = np.arange(self.files_count)
        self.classes_count = len(classes)

        self.on_epoch_end()  # to initialize transformations and shuffle indices

        print("get %d classes for %d files for %s" % (
            self.classes_count,
            self.files_count,
            'train' if _validation_data is None else 'validation'))

    def get_validation_generator(self):
        """Return the validation generator if you've provided a split factor."""
        return self.__class__(
            nbframe=self.nbframe,
            nb_channel=self.nb_channel,
            target_shape=self.target_shape,
            classes=self.classes,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            rescale=self.rescale,
            # these two were silently dropped by the original, so the
            # validation generator ignored transformations and caching
            transformation=self.transformation,
            use_frame_cache=self.use_frame_cache,
            _validation_data=self.validation)

    def on_epoch_end(self):
        """Shuffle indices and pre-draw one random transformation per file."""
        # prepare transformations here to avoid __getitem__ re-drawing them
        if self.transformation is not None:
            self._random_trans = [
                self.transformation.get_random_transform(self.target_shape)
                for _ in range(self.files_count)
            ]
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Number of full batches per epoch (an incomplete tail is dropped)."""
        return int(np.floor(self.files_count / self.batch_size))

    def _read_frames(self, video, shape, nbframe):
        """Decode ``video`` and return ``nbframe`` evenly spaced frames.

        Frames are resized to ``shape``, converted to RGB or grayscale,
        converted to arrays and rescaled. The first and last frames of the
        video are dropped. Raises ValueError if the video is too short to
        provide ``nbframe`` frames.
        """
        cap = cv.VideoCapture(video)
        frames = []
        try:
            while True:
                grabbed, frame = cap.read()
                if not grabbed:
                    # end of video
                    break
                # resize
                frame = cv.resize(frame, shape)
                # OpenCV decodes to BGR; the original used COLOR_RGB2GRAY
                # here, which applies the red/blue luma weights to the
                # wrong channels on BGR input
                if self.nb_channel == 3:
                    frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                else:
                    frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
                # to numpy array, rescaled
                frame = keras.preprocessing.image.img_to_array(frame) * self.rescale
                frames.append(frame)
        finally:
            # always free the decoder handle (the original leaked it)
            cap.release()

        # Add 2 virtual slots so the first and last frame are dropped.
        # Clamp the step to 1: short videos previously made it 0 and
        # raised the cryptic "slice step cannot be zero".
        jump = max(1, len(frames) // (nbframe + 2))
        frames = frames[jump::jump][:nbframe]
        if len(frames) < nbframe:
            # fail loudly with the file name instead of producing a
            # ragged batch that breaks np.array() downstream
            raise ValueError(
                "video %s provides only %d frames, %d are required" % (
                    video, len(frames), nbframe))
        return frames

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, labels)`` numpy arrays."""
        classes = self.classes
        shape = self.target_shape
        nbframe = self.nbframe

        labels = []
        images = []
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        for i in indexes:
            # fetch the transformation pre-drawn in on_epoch_end, if any
            t = self._random_trans[i] if self.transformation is not None else None

            video = self.files[i]
            # the class name is the parent directory of the video file;
            # os.path handles mixed '/' and os.sep separators, unlike the
            # original split(os.sep)
            cl = os.path.basename(os.path.dirname(video))

            # create a label array and set 1 to the right column
            label = np.zeros(len(classes))
            col = classes.index(cl)
            label[col] = 1.

            if video not in self.__frame_cache:
                frames = self._read_frames(video, shape, nbframe)
                # add to frame cache to not read from disk later
                if self.use_frame_cache:
                    self.__frame_cache[video] = frames
            else:
                frames = self.__frame_cache[video]

            # apply the same transformation to every frame of the sequence
            if t is not None:
                frames = [self.transformation.apply_transform(frame, t)
                          for frame in frames]

            # add the sequence to the batch
            images.append(frames)
            labels.append(label)

        return np.array(images), np.array(labels)
Author
metal3d
commented
Oct 26, 2020
via email
Hello, I made a more tweakable generator here
https://github.com/metal3d/keras-video-generators
Le jeu. 15 oct. 2020 à 18:29, Connor Eaton <notifications@github.com> a
écrit :
… ***@***.**** commented on this gist.
------------------------------
Hello, I found this searching for custom keras generators allowing me to
read and transform multiple frames at once and send them to a model as 1
input. For example I want to read, transform, and stack arrays from 10
consecutive frames and send to model as 1 input. I think this code does
that. However, I am getting this error:
x = VideoFrameGenerator(
rescale=1/255.,
nbframe=10,
classes=list(range(0,14)),
batch_size=16,
use_frame_cache=False,
target_shape=(224, 224),
shuffle=True,
transformation=None,
split=.2,
nb_channel=3,
glob_pattern='../iemocap_data/label_split_videos/{classname}/*.mp4',
_validation_data=None
)
Which prints:
class 0, validation count: 22
class 1, validation count: 3
class 2, validation count: 26
class 3, validation count: 80
class 4, validation count: 129
class 5, validation count: 60
class 6, validation count: 8
class 7, validation count: 37
class 8, validation count: 199
class 9, validation count: 160
class 10, validation count: 98
class 11, validation count: 44
class 12, validation count: 43
class 13, validation count: 21
get 14 classes for 3751 files for train
And then I run as a test:
from keras.models import Sequential
from keras.layers import Dense
# define the keras model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(14, activation='softmax'))
# compile the keras model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fit the keras model on the dataset
model.fit_generator(x,
steps_per_epoch=100,
epochs=2)
Which results in this error:
Epoch 1/2
3117
1811
1273
721
390
370
1557
682
2381
2445
2302
1742
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-73-4191a4b175bb> in <module>
12 model.fit_generator(x,
13 steps_per_epoch=100,
---> 14 epochs=2)
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1730 use_multiprocessing=use_multiprocessing,
1731 shuffle=shuffle,
-> 1732 initial_epoch=initial_epoch)
1733
1734 @interfaces.legacy_generator_methods_support
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
183 batch_index = 0
184 while steps_done < steps_per_epoch:
--> 185 generator_output = next(output_generator)
186
187 if not hasattr(generator_output, '__len__'):
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/utils/data_utils.py in get(self)
623 except Exception:
624 self.stop()
--> 625 six.reraise(*sys.exc_info())
626
627
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
701 if value.__traceback__ is not tb:
702 raise value.with_traceback(tb)
--> 703 raise value
704 finally:
705 value = None
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/utils/data_utils.py in get(self)
608 try:
609 future = self.queue.get(block=True)
--> 610 inputs = future.get(timeout=30)
611 except mp.TimeoutError:
612 idx = future.idx
~/anaconda3/envs/tensorflow_2/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
655 return self._value
656 else:
--> 657 raise self._value
658
659 def _set(self, i, obj):
~/anaconda3/envs/tensorflow_2/lib/python3.7/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
119 job, i, func, args, kwds = task
120 try:
--> 121 result = (True, func(*args, **kwds))
122 except Exception as e:
123 if wrap_exception and func is not _helper_reraises_exception:
~/anaconda3/envs/tensorflow_2/lib/python3.7/site-packages/keras/utils/data_utils.py in get_index(uid, i)
404 The value at index `i`.
405 """
--> 406 return _SHARED_SEQUENCES[uid][i]
407
408
<ipython-input-65-6f42e0ddb651> in __getitem__(self, index)
170 # create a label array and set 1 to the right column
171 label = np.zeros(len(classes))
--> 172 col = classes.index(cl)
173 label[col] = 1.
174
ValueError: '10' is not in list
The value in the ValueError ('10' above) changes each time I run. Sometimes it's 1, 4, 2,
etc. If you have any ideas what I am doing wrong, that would be great.
Thanks!
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
<https://gist.github.com/d5e93331b04c44aa52cbfe11db4afff2#gistcomment-3490744>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAAYN4F27OX2ZKRPK6DC2ZDSK4PQFANCNFSM4SSHADSA>
.
--
Patrice FERLET
http://www.metal3d.org
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment